# Commit 0d97cc8c authored by Sugon_ldc's avatar Sugon_ldc
# Browse files
#
# add new model
#
# parents
# Pipeline #316 failed with stages
# in 0 seconds
import logging
import os
import os.path as osp
import pathlib
import time
import json
from functools import partial
import random
import threading
import qt
import ctk
import vtk
import numpy as np
import SimpleITK as sitk
import slicer
from slicer.ScriptedLoadableModule import *
from slicer.util import VTKObservationMixin
# when test, wont use any paddle related funcion
HERE = pathlib.Path(__file__).parent.absolute()
TEST = osp.exists(HERE / "TEST")
# Quiet sub-ERROR log noise outside of test mode.
if not TEST:
    logging.getLogger().setLevel(logging.ERROR)

if not TEST:
    # Lazily install the deep-learning dependencies on first run, asking the
    # user for confirmation before touching their Python environment.
    try:
        import paddle
    except ModuleNotFoundError:
        if slicer.util.confirmOkCancelDisplay(
            "This module requires 'paddlepaddle' Python package. Click OK to install it now."
        ):
            slicer.util.pip_install("paddlepaddle")
            import paddle
    try:
        import paddleseg
    except ModuleNotFoundError:
        if slicer.util.confirmOkCancelDisplay(
            "This module requires 'paddleseg' Python package. Click OK to install it now."
        ):
            slicer.util.pip_install("paddleseg")
            # BUG FIX: this branch previously re-imported ``paddle`` instead of
            # the just-installed ``paddleseg``, leaving paddleseg unimported.
            import paddleseg
import inference
import inference.predictor as predictor
# TODO: get some better color map
# Fixed palette of normalized (r, g, b) tuples, cycled through when assigning
# display colors to segmentation segments.
colors = [
    (0.5019607843137255, 0.6823529411764706, 0.5019607843137255),
    (0.9450980392156862, 0.8392156862745098, 0.5686274509803921),
    (0.6941176470588235, 0.4784313725490196, 0.3960784313725490),
    (0.4352941176470588, 0.7215686274509804, 0.8235294117647058),
    (0.8470588235294118, 0.3960784313725490, 0.3098039215686274),
    (0.8666666666666667, 0.5098039215686274, 0.3960784313725490),
    (0.5647058823529412, 0.9333333333333333, 0.5647058823529412),
    (0.7529411764705882, 0.4078431372549019, 0.3450980392156862),
    (0.8627450980392157, 0.9607843137254902, 0.0784313725490196),
    (0.3058823529411765, 0.2470588235294117, 0.0000000000000000),
    (1.0000000000000000, 0.9803921568627451, 0.8627450980392157),
    (0.9019607843137255, 0.8627450980392157, 0.2745098039215685),
    (0.7843137254901961, 0.7843137254901961, 0.9215686274509803),
    (0.9803921568627451, 0.9803921568627451, 0.8235294117647058),
]
#
# EISegMed3D
#
class EISegMed3D(ScriptedLoadableModule):
    """Module descriptor for the EISegMed3D interactive segmentation extension.

    Uses ScriptedLoadableModule base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def __init__(self, parent):
        # Populate the metadata Slicer displays in the module selector.
        ScriptedLoadableModule.__init__(self, parent)
        self.parent.title = "EISegMed3D"  # TODO: make this more human readable by adding spaces
        self.parent.categories = [
            "Interactive Segmentation"
        ]  # TODO: set categories (folders where the module shows up in the module selector)
        self.parent.dependencies = [
        ]  # TODO: add here list of module names that this module requires
        self.parent.contributors = ["Lin Han, Daisy (Baidu Corp.)"]
        # TODO: update with short description of the module and a link to online module documentation
        self.parent.helpText = """
This is an example of scripted loadable module bundled in an extension.
See more information in <a href="https://github.com/organization/projectname#EISegMed3D">module documentation</a>.
"""
        # TODO: replace with organization, grant and thanks
        self.parent.acknowledgementText = """
This file was originally developed by Jean-Christophe Fillion-Robin, Kitware Inc., Andras Lasso, PerkLab,
and Steve Pieper, Isomics, Inc. and was partially funded by NIH grant 3P41RR013218-12S1.
"""
        # Additional initialization step after application startup is complete
        slicer.app.connect("startupCompleted()", self.initializeAfterStartup)

    def initializeAfterStartup(self):
        # Hook for post-startup work; currently nothing to do.
        # print("initializeAfterStartup", slicer.app.commandOptions().noMainWindow)
        pass
class Clicker(object):
    """Accumulates the positive/negative interaction clicks for one segment."""

    def __init__(self):
        self.reset_clicks()

    def get_clicks(self, clicks_limit=None):
        """Return the first ``clicks_limit`` clicks (all of them when None)."""
        return self.clicks_list[:clicks_limit]

    def add_click(self, click):
        """Record ``click``: assign it a sequential ``index`` and bump the
        positive/negative counter matching ``click.is_positive``."""
        # BUG FIX (minor): removed unused local ``coords = click.coords``.
        click.index = self.num_pos_clicks + self.num_neg_clicks
        if click.is_positive:
            self.num_pos_clicks += 1
        else:
            self.num_neg_clicks += 1
        self.clicks_list.append(click)

    def reset_clicks(self):
        """Forget all recorded clicks and reset both counters."""
        self.num_pos_clicks = 0
        self.num_neg_clicks = 0
        self.clicks_list = []

    def __len__(self):
        return len(self.clicks_list)
#
# Register sample data sets in Sample Data module
#
def registerSampleData():
    """
    Add data sets to Sample Data module.
    """
    # It is always recommended to provide sample data for users to make it easy to try the module,
    # but if no sample data is available then this method (and associated startupCompeted signal connection) can be removed.
    import SampleData

    iconsPath = os.path.join(os.path.dirname(__file__), "Resources/Icons")
    # To ensure that the source code repository remains small (can be downloaded and installed quickly)
    # it is recommended to store data sets that are larger than a few MB in a Github release.
    # EISegMed3D1
    SampleData.SampleDataLogic.registerCustomSampleDataSource(
        # Category and sample name displayed in Sample Data module
        category="placePoint",
        sampleName="placePoint1",
        # Thumbnail should have size of approximately 260x280 pixels and stored in Resources/Icons folder.
        # It can be created by Screen Capture module, "Capture all views" option enabled, "Number of images" set to "Single".
        thumbnailFileName=os.path.join(iconsPath, "placePoint1.png"),
        # Download URL and target file name
        uris="https://github.com/Slicer/SlicerTestingData/releases/download/SHA256/998cb522173839c78657f4bc0ea907cea09fd04e44601f17c82ea27927937b95",
        fileNames="placePoint1.nrrd",
        # Checksum to ensure file integrity. Can be computed by this command:
        # import hashlib; print(hashlib.sha256(open(filename, "rb").read()).hexdigest())
        checksums="SHA256:998cb522173839c78657f4bc0ea907cea09fd04e44601f17c82ea27927937b95",
        # This node name will be used when the data set is loaded
        nodeNames="placePoint1", )
    # EISegMed3D2
    SampleData.SampleDataLogic.registerCustomSampleDataSource(
        # Category and sample name displayed in Sample Data module
        category="placePoint",
        sampleName="placePoint2",
        thumbnailFileName=os.path.join(iconsPath, "placePoint2.png"),
        # Download URL and target file name
        uris="https://github.com/Slicer/SlicerTestingData/releases/download/SHA256/1a64f3f422eb3d1c9b093d1a18da354b13bcf307907c66317e2463ee530b7a97",
        fileNames="placePoint2.nrrd",
        checksums="SHA256:1a64f3f422eb3d1c9b093d1a18da354b13bcf307907c66317e2463ee530b7a97",
        # This node name will be used when the data set is loaded
        nodeNames="placePoint2", )
#
# EISegMed3DWidget
#
class EISegMed3DWidget(ScriptedLoadableModuleWidget, VTKObservationMixin):
"""Uses ScriptedLoadableModuleWidget base class, available at:
https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
"""
    def __init__(self, parent=None):
        """
        Called when the user opens the module the first time and the widget is initialized.
        """
        ScriptedLoadableModuleWidget.__init__(self, parent)
        VTKObservationMixin.__init__(
            self)  # needed for parameter node observation
        self.logic = None
        self._parameterNode = None
        # data var
        self._dataFolder = None      # folder currently being annotated
        self._scanPaths = []         # absolute paths of every scan in the folder
        self._finishedPaths = []     # scans the user marked as finished
        self._currScanIdx = None     # index into self._scanPaths of the shown scan
        self._currVolumeNode = None  # volume node of the scan on display
        self.dgPositivePointListNode = None
        self.dgPositivePointListNodeObservers = []
        self.dgNegativePointListNode = None
        self.dgNegativePointListNodeObservers = []
        self._prevCatg = None        # last-seen {segmentId: name}, used to diff category edits
        self._loadingScans = set()   # paths whose volume load is still in flight
        self.pb = None               # shared progress dialog (see initPb/setPb/closePb)
        # status var (re-entrancy / state guards)
        self._turninig = False       # turnTo in progress (sic: "turning")
        self._dirty = False
        self._syncingCatg = False
        self._usingInteractive = False
        self._updatingGUIFromParameterNode = False
        self._endImportProcessing = False
        self._addingControlPoint = False
        self._lastTurnNextScan = True
        self.init_params()
    def setup(self):
        """
        Called when the user opens the module the first time and the widget is initialized.

        Loads the .ui layout, creates the logic object, wires scene/UI signals
        and configures the positive/negative point-placement widgets.
        """
        ScriptedLoadableModuleWidget.setup(self)
        # Load widget from .ui file (created by Qt Designer).
        uiWidget = slicer.util.loadUI(self.resourcePath("UI/EISegMed3D.ui"))
        self.layout.addWidget(uiWidget)
        self.ui = slicer.util.childWidgetVariables(uiWidget)
        uiWidget.setMRMLScene(slicer.mrmlScene)
        # Create logic class. Logic implements all computations that should be possible to run
        # in batch mode, without a graphical user interface.
        # TODO: we may not need logic. user have to interact
        self.logic = EISegMed3DLogic()
        # Connections
        # These connections ensure that we update parameter node when scene is closed
        self.addObserver(slicer.mrmlScene, slicer.mrmlScene.StartCloseEvent,
                         self.onSceneStartClose)
        self.addObserver(slicer.mrmlScene, slicer.mrmlScene.EndCloseEvent,
                         self.onSceneEndClose)
        self.addObserver(slicer.mrmlScene, slicer.mrmlScene.NodeAddedEvent,
                         self.onSceneEndImport)
        # TODO: is syncing settings between node and gui on show/scenestart/... necessary
        # button, slider
        self.ui.loadModelButton.connect("clicked(bool)", self.loadModelClicked)
        self.ui.nextScanButton.connect("clicked(bool)",
                                       lambda p: self.nextScan())
        self.ui.prevScanButton.connect("clicked(bool)",
                                       lambda p: self.prevScan())
        self.ui.finishScanButton.connect("clicked(bool)", self.finishScan)
        self.ui.finishSegmentButton.connect("clicked(bool)",
                                            self.exitInteractiveMode)
        self.ui.opacitySlider.connect("valueChanged(double)",
                                      self.opacityUi2Display)
        self.ui.dataFolderButton.connect("directoryChanged(QString)",
                                         self.loadScans)
        self.ui.skipFinished.connect("clicked(bool)", self.skipFinishedToggled)
        iconPath = HERE / "Resources" / "Icons"
        self.ui.nextScanButton.setIcon(qt.QIcon(iconPath / "next.png"))
        self.ui.prevScanButton.setIcon(qt.QIcon(iconPath / "prev.png"))
        self.ui.finishSegmentButton.setIcon(qt.QIcon(iconPath / "done.png"))
        self.ui.finishScanButton.setIcon(qt.QIcon(iconPath / "save.png"))
        # positive/negative control point
        self.ui.dgPositiveControlPointPlacementWidget.setMRMLScene(
            slicer.mrmlScene)
        self.ui.dgPositiveControlPointPlacementWidget.placeButton(
        ).toolTip = "Add positive points"
        self.ui.dgPositiveControlPointPlacementWidget.placeButton().show()
        self.ui.dgPositiveControlPointPlacementWidget.deleteButton(
        ).setFixedHeight(0)  # disable the delete-point button by shrinking it away
        self.ui.dgPositiveControlPointPlacementWidget.deleteButton(
        ).setFixedWidth(0)
        self.ui.dgPositiveControlPointPlacementWidget.placeButton().connect(
            "clicked(bool)", self.enterInteractiveMode)
        self.ui.dgNegativeControlPointPlacementWidget.setMRMLScene(
            slicer.mrmlScene)
        self.ui.dgNegativeControlPointPlacementWidget.placeButton(
        ).toolTip = "Add negative points"
        self.ui.dgNegativeControlPointPlacementWidget.placeButton().show()
        self.ui.dgNegativeControlPointPlacementWidget.deleteButton(
        ).setFixedHeight(0)
        self.ui.dgNegativeControlPointPlacementWidget.deleteButton(
        ).setFixedWidth(0)
        self.ui.dgNegativeControlPointPlacementWidget.placeButton().connect(
            "clicked(bool)", self.enterInteractiveMode)
        # segment editor
        self.ui.embeddedSegmentEditorWidget.setMRMLScene(slicer.mrmlScene)
        self.ui.embeddedSegmentEditorWidget.setMRMLSegmentEditorNode(
            self.logic.get_segment_editor_node())
        self.initializeFromNode()
        # Set place point widget colors (green = positive, red = negative)
        self.ui.dgPositiveControlPointPlacementWidget.setNodeColor(
            qt.QColor(0, 255, 0))
        self.ui.dgNegativeControlPointPlacementWidget.setNodeColor(
            qt.QColor(255, 0, 0))
    def init_params(self):
        """init changble parameters here"""
        self.predictor_params_ = {"norm_radius": 2, "spatial_scale": 1.0}
        # xyz scaling; must match the preprocessing shape the model was trained
        # with. TODO: expose the preprocessing shape in the UI so it can follow
        # the loaded model. (Translated from the original Chinese note.)
        self.ratio = (
            512 / 880, 512 / 880, 12 /
            12)
        self.train_shape = (512, 512, 12)
        self.image_ww = (0, 2650)  # low, high range for image crop
        self.test_iou = False  # the label file need to be set correctly
        self.file_suffix = [".nii",
                            ".nii.gz"]  # files with these suffix will be loaded
        if TEST:
            self.device, self.enable_mkldnn = "cpu", True
        else:
            self.device, self.enable_mkldnn = "gpu", True
def clearScene(self, clearAllVolumes=False):
if clearAllVolumes:
for node in slicer.util.getNodesByClass("vtkMRMLScalarVolumeNode"):
slicer.mrmlScene.RemoveNode(node)
segmentationNode = self.segmentationNode
if segmentationNode is not None:
slicer.mrmlScene.RemoveNode(segmentationNode)
""" progress bar related """
def skipFinishedToggled(self, skipFinished):
self.togglePrevNextBtn(self._currScanIdx)
if not skipFinished and self._currVolumeNode is None:
self.turnTo(self._currScanIdx)
def initPb(self, label="Processing..", windowTitle=None):
if self.pb is None:
self.pb = slicer.util.createProgressDialog()
self.pb.setCancelButtonText("Close")
self.pb.setAutoClose(True)
self.pb.show()
self.pb.activateWindow()
self.pb.setValue(0)
self.pbLeft = 100
else:
self.pbLeft = 100 - self.pb.value
if windowTitle is not None:
self.pb.setWindowTitle(windowTitle)
self.pb.setLabelText(label)
slicer.app.processEvents()
def setPb(self, percentage, label=None, windowTitle=None):
self.pb.setValue(100 - int(self.pbLeft * (1 - percentage)))
if label is not None:
self.pb.setLabelText(label)
if windowTitle is not None:
self.pb.setWindowTitle(windowTitle)
slicer.app.processEvents()
def closePb(self):
if self.pb is None:
return
pb = self.pb
self.pb = None
pb.close()
""" load/change scan related """
    def loadScans(self, dataFolder):
        """Get all the scans under a folder and turn to the first one"""
        self.initPb("Making sure input valid", "Loading scans")
        # 1. ensure valid input
        if dataFolder is None or len(dataFolder) == 0:
            slicer.util.errorDisplay(
                "Please select a Data Folder first!", autoCloseMsec=5000)
            return
        if not osp.exists(dataFolder):
            slicer.util.errorDisplay(
                f"The Data Folder( {dataFolder} ) doesn't exist!",
                autoCloseMsec=2000)
            return
        self.clearScene()
        self.setPb(0.2, "Searching for scans")
        # 2. list files in assigned directory
        self._dataFolder = dataFolder
        # keep files whose full extension (everything from the first dot on)
        # is one of self.file_suffix
        paths = [
            p for p in os.listdir(self._dataFolder)
            if p[p.find("."):] in self.file_suffix
        ]
        # drop "*_label" files: those are annotations, not scans
        paths = [
            p for p in paths if p.split(".")[0][-len("_label"):] != "_label"
        ]
        paths.sort()
        self._scanPaths = [osp.join(self._dataFolder, p) for p in paths]
        if len(paths) == 0:
            self.closePb()
            slicer.util.errorDisplay(
                f"No file ending with {' or '.join(self.file_suffix)} is found under {self._dataFolder}.\nDid you chose the wrong folder?"
            )
            return
        self.setPb(0.5,
                   f"Found {len(paths)} scans in folder {self._dataFolder}")
        self._currScanIdx, self._finishedPaths = self.getProgress()
        self.updateProgressWidgets()
        if len(set(self._scanPaths) - set(
                self._finishedPaths)) == 0 and self.ui.skipFinished.checked:
            self.closePb()
            slicer.util.delayDisplay(
                f"All {len(self._scanPaths)} scans have been annotated!\nUncheck Skip Finished Scans to browse through them.",
                4000, )
            return
        self.setPb(0.6, "Loading Scan and label")
        # step back one so nextScan() lands on the left-off scan itself; when
        # nothing lies ahead, step past it and fall back with prevScan()
        self._currScanIdx -= 1
        found = self.nextScan(silentFail=True)
        if not found:
            self._currScanIdx += 2
            self.prevScan()
        self.ui.finishScanButton.setEnabled(True)
        logging.info(
            f"All scans found under {self._dataFolder} are{','.join([' '+osp.basename(p) for p in self._scanPaths])}"
        )
def togglePrevNextBtn(self, currIdx):
if currIdx is None:
return
self.ui.prevScanButton.setEnabled(
self.getTurnToTaskId(currIdx, "prev") is not None)
self.ui.nextScanButton.setEnabled(
self.getTurnToTaskId(currIdx, "next") is not None)
def nextScan(self, silentFail=False):
self.saveSegmentation()
nextIdx = self.getTurnToTaskId(self._currScanIdx, "next")
if nextIdx is None:
self.ui.nextScanButton.setEnabled(False)
if not silentFail:
slicer.util.errorDisplay(
f"This is the last unannotated scan. No next scan")
return False
self._lastTurnNextScan = True
self.turnTo(nextIdx)
return True
def prevScan(self, silentFail=False):
self.saveSegmentation()
prevIdx = self.getTurnToTaskId(self._currScanIdx, "prev")
if prevIdx is None:
self.ui.prevScanButton.setEnabled(False)
if not silentFail:
slicer.util.errorDisplay(
f"This is the first unannotated scan. No previous scan")
return False
self._lastTurnNextScan = False
self.turnTo(prevIdx)
return True
def getTurnToTaskId(self, currIdx, direction, skipFinished=None):
if skipFinished is None:
skipFinished = self.ui.skipFinished.checked
if direction == "next":
while True:
if currIdx >= len(self._scanPaths) - 1:
return None
currIdx += 1
if not skipFinished:
break
if self._scanPaths[currIdx] not in self._finishedPaths:
break
return currIdx
else:
while True:
if currIdx <= 0:
return None
currIdx -= 1
if not skipFinished:
break
if self._scanPaths[currIdx] not in self._finishedPaths:
break
return currIdx
def getScan(self, scanPath, wait=True):
try:
return slicer.util.getNode(osp.basename(scanPath))
except slicer.util.MRMLNodeNotFoundException:
if scanPath in self._loadingScans: # scan hasn't finished loading
if wait:
timeout = 30
while True:
if scanPath not in self._loadingScans:
return slicer.util.getNode(osp.basename(scanPath))
logging.info("waiting", scanPath, timeout)
time.sleep(0.1)
timeout -= 1
if timeout == 0:
return None
else:
return None
else:
if wait:
logging.info(f"loading {scanPath}")
self._loadingScans.add(scanPath)
node = slicer.util.loadVolume(
scanPath,
properties={"show": False,
"singleFile": True})
node.SetName(osp.basename(scanPath))
self._loadingScans.remove(scanPath)
return node
else:
def read(path):
node = slicer.util.loadVolume(
scanPath,
properties={"show": False,
"singleFile": True})
node.SetName(osp.basename(path))
qt.QTimer.singleShot(
random.randint(500, 1000), lambda: read(scanPath))
def manageCache(self, currIdx, skipPreload=False):
toKeepIdxs = [
self.getTurnToTaskId(currIdx, "prev"),
currIdx,
self.getTurnToTaskId(currIdx, "next"),
]
toKeepPaths = [
self._scanPaths[idx] for idx in toKeepIdxs if idx is not None
]
allVolumes = slicer.util.getNodesByClass("vtkMRMLScalarVolumeNode")
for volume in allVolumes:
if volume.GetName() not in map(osp.basename, toKeepPaths):
slicer.mrmlScene.RemoveNode(volume)
for path in toKeepPaths:
self.getScan(path, wait=False)
def turnTo(self, turnToIdx, skipPreload=False):
"""
Turn to the turnToIdx th scan, load scan and label
"""
if turnToIdx == self._currScanIdx:
return False
if self._turninig:
return
self._turninig = True
# 0. clear nodes from previous task and prepare states
self.initPb("Preparing to load", "Load scan and label")
self.setPb(0.1)
if self.segmentation is not None:
self.saveSegmentation()
if self._usingInteractive:
self.exitInteractiveMode()
if len(self._scanPaths) == 0:
slicer.util.errorDisplay(
"No scan found, please load scans first.", autoCloseMsec=2000)
self.closePb()
self._turninig = False
return
logging.info(
f"Turning to the {turnToIdx}th scan, path is {self._scanPaths[turnToIdx]}"
)
self.ui.dgPositiveControlPointPlacementWidget.setEnabled(False)
self.ui.dgNegativeControlPointPlacementWidget.setEnabled(False)
self.clearScene() # remove segmentation node and control points
self._currScanIdx = turnToIdx
slicer.app.processEvents()
# 1. load new scan & preprocess
image_path = self._scanPaths[turnToIdx]
self.setPb(0.2, f"Loading {osp.basename(image_path)}")
self._currVolumeNode = self.getScan(image_path)
self._currVolumeNode.SetName(osp.basename(image_path))
self.manageCache(turnToIdx, skipPreload=skipPreload)
# 2. load segmentation or create an empty one
self.setPb(0.8, "Loading segmentation")
dot_pos = image_path.find(".")
self._currLabelPath = image_path[:dot_pos] + "_label" + image_path[
dot_pos:]
if osp.exists(self._currLabelPath):
segmentNode = slicer.modules.segmentations.logic(
).LoadSegmentationFromFile(self._currLabelPath, False)
else:
segmentNode = slicer.mrmlScene.AddNewNodeByClass(
"vtkMRMLSegmentationNode")
segmentNode.SetName("EISegMed3DSegmentation")
segmentNode.SetReferenceImageGeometryParameterFromVolumeNode(
self._currVolumeNode)
slicer.app.processEvents()
slicer.app.processEvents()
# update category info
for segment in self.segments:
if segment.GetNameAutoGenerated():
segment.SetName(f"Segment_{segment.GetLabelValue()}")
self._prevCatg = None
self.catgFile2Segmentation()
self.catgSegmentation2File()
for idx, segment in enumerate(self.segments):
# print(f"setting color for segment name: {segment.GetName()}, color: {colors[idx % len(colors)]}")
segment.SetColor(colors[idx % len(colors)])
segment.SetColor(colors[idx % len(colors)])
segment.SetColor(colors[idx % len(colors)])
def sync(*args):
if not self._turninig:
self.catgSegmentation2File()
def setDirty(*args):
if not self._turninig:
self._dirty = True
segmentNode.AddObserver(
segmentNode.GetContentModifiedEvents().GetValue(5), sync)
segmentNode.AddObserver(
segmentNode.GetContentModifiedEvents().GetValue(4), sync)
segmentNode.AddObserver(
segmentNode.GetContentModifiedEvents().GetValue(1), setDirty)
# add: 3, 5
# edit: 5
# remove: 1, 2, 4 (will be triggered when turn task)
# 3. create category label from txt and segmentation
self.setPb(0.8, "Syncing progress")
self.saveProgress()
self.updateProgressWidgets()
# 4. set image
self.setPb(0.9, "Preprocessing image for interactive segmentation")
if not TEST:
self.prepImage()
# 5. set the editor as current result.
self.setPb(0.95, "Wrapping up")
self.ui.embeddedSegmentEditorWidget.setSegmentationNode(segmentNode)
self.ui.embeddedSegmentEditorWidget.setMasterVolumeNode(
self._currVolumeNode)
self.ui.dgPositiveControlPointPlacementWidget.setEnabled(True)
self.ui.dgNegativeControlPointPlacementWidget.setEnabled(True)
# 6. change button state
self.togglePrevNextBtn(self._currScanIdx)
layoutManager = slicer.app.layoutManager()
for sliceViewName in layoutManager.sliceViewNames():
layoutManager.sliceWidget(sliceViewName).mrmlSliceNode(
).RotateToVolumePlane(self._currVolumeNode)
slicer.util.resetSliceViews()
self.closePb()
self._turninig = False
""" category and segmentation management """
@property
def segmentationNode(self):
try:
return slicer.util.getNode("EISegMed3DSegmentation")
except slicer.util.MRMLNodeNotFoundException:
return None
@property
def segmentation(self):
segmentationNode = self.segmentationNode
if segmentationNode is None:
return None
return segmentationNode.GetSegmentation()
@property
def segments(self):
segmentation = self.segmentation
if segmentation is None:
return []
for segId in segmentation.GetSegmentIDs():
yield segmentation.GetSegment(segId)
@property
def configPath(self):
if self._dataFolder is None:
return None
return osp.join(self._dataFolder, "EISegMed3D.json")
def getConfig(self):
skeleton = {"labels": [], "finished": [], "leftOff": ""}
if not osp.exists(self.configPath):
return skeleton
try:
config = json.loads(open(self.configPath, "r").read())
return config
except:
return skeleton
def getSegmentId(self, segment):
segmentation = self.segmentation
for segId in segmentation.GetSegmentIDs():
if segmentation.GetSegment(segId) == segment:
return segId
def getCatgFromFile(self):
"""Parse category info from EISegMed3D.json
Returns:
dict: {name: labelValue, ... }
"""
config = self.getConfig()
catg = {}
for info in config.get("labels", []):
catg[info["name"]] = int(info["labelValue"])
return catg
    def catgFile2Segmentation(self):
        """Sync category info from EISegMed3D.json to segmentation
        match by labelValue
        - create if missing
        - correct name if segmentation differs from EISegMed3D.json
        """
        if self._syncingCatg:  # re-entrancy guard; observers may fire mid-edit
            return
        self._syncingCatg = True
        # 1. get info from config file
        name2value = self.getCatgFromFile()
        value2name = {value: name for name, value in name2value.items()}
        # 2. set segmentation's names
        segmentValues = []
        for segment in self.segments:
            labelValue = segment.GetLabelValue()
            segmentValues.append(labelValue)
            name = value2name.get(labelValue, None)
            if name is not None:
                segment.SetName(name)
        # 3. create missing categories (values present in the file only)
        for labelValue in set(value2name.keys()) - set(segmentValues):
            segmentId = self.segmentation.AddEmptySegment(
                "", value2name[labelValue])
            self.segmentation.GetSegment(segmentId).SetLabelValue(labelValue)
        self._syncingCatg = False
def catgSegmentation2File(self):
"""Sync category info from segmentation to EISegMed3D.json
match by name
- sync user change name
- sync user add
- sync user delete
"""
if self._syncingCatg:
return
self._syncingCatg = True
# 1. if no prev catg record, record current
segmentation = self.segmentation
if self._prevCatg is None:
self._prevCatg = {
segId: segmentation.GetSegment(segId).GetName()
for segId in segmentation.GetSegmentIDs()
}
# 2. change name, add to file or delete
name2value = self.getCatgFromFile()
for segmentId in segmentation.GetSegmentIDs():
segment = segmentation.GetSegment(segmentId)
# change name
if segmentId in self._prevCatg.keys() and segment.GetName(
) != self._prevCatg[segmentId]:
del name2value[self._prevCatg[segmentId]]
name2value[segment.GetName()] = segment.GetLabelValue()
# user add or this segmentation have more catgs
if segment.GetName() not in name2value.keys():
if segment.GetNameAutoGenerated():
segment.SetName(f"Segment_{segment.GetLabelValue()}")
name2value[segment.GetName()] = segment.GetLabelValue()
# delete
for segmentId in set(self._prevCatg.keys()) - set(
segmentation.GetSegmentIDs()):
logging.info(
f"deleting segment {segmentId} {self.segmentation.GetSegment(segmentId)}"
)
del name2value[self._prevCatg[segmentId]]
# 3. record catg info
self._prevCatg = {
segId: segmentation.GetSegment(segId).GetName()
for segId in segmentation.GetSegmentIDs()
}
# 4. write to file
config = self.getConfig()
config["labels"] = [{
"name": name,
"labelValue": value
} for name, value in name2value.items()]
print(json.dumps(config), file=open(self.configPath, "w"))
self._syncingCatg = False
""" task progress related """
def saveProgress(self):
config = self.getConfig()
relpath = lambda path: osp.relpath(path, self._dataFolder)
config["finished"] = [relpath(p) for p in self._finishedPaths]
config["leftOff"] = relpath(self._scanPaths[self._currScanIdx])
print(json.dumps(config), file=open(self.configPath, "w"))
def getProgress(self):
config = self.getConfig()
leftOffIdx = 0
if "leftOff" in config.keys():
for idx, p in enumerate(self._scanPaths):
if p == osp.join(self._dataFolder, config["leftOff"]):
leftOffIdx = idx
return leftOffIdx, [
osp.join(self._dataFolder, p) for p in config.get("finished", [])
]
def updateProgressWidgets(self):
self.ui.annProgressBar.setValue(
int(100 * len(self._finishedPaths) / len(self._scanPaths)))
self.ui.progressDetail.setText(
f"Finished: {len(self._finishedPaths)} / Total: {len(self._scanPaths)}"
)
def toggleFinished(idx, *args):
logging.info(idx, *args)
if self._scanPaths[idx] in self._finishedPaths:
self._finishedPaths.remove(self._scanPaths[idx])
else:
self._finishedPaths.append(self._scanPaths[idx])
self.saveProgress()
self.updateProgressWidgets()
def pathDoubleClicked(row, col):
if self._currScanIdx == row:
return
if col != 1:
return
self.turnTo(row, skipPreload=True)
table = self.ui.progressTable
table.setRowCount(len(self._scanPaths))
for idx, path in enumerate(self._scanPaths):
layout = qt.QVBoxLayout()
checkbox = qt.QCheckBox()
checkbox.setChecked(path in self._finishedPaths)
checkbox.toggled.connect(partial(toggleFinished, idx))
layout.addWidget(checkbox)
wrapper = qt.QWidget()
wrapper.setLayout(layout)
table.setCellWidget(idx, 0, wrapper)
table.setItem(
idx, 1,
qt.QTableWidgetItem(osp.relpath(path, self._dataFolder)))
table.cellDoubleClicked.connect(pathDoubleClicked)
table.resizeColumnsToContents()
# ugly fix. second colum wont strength after setting data
self.ui.progressCollapse.toggle()
self.ui.progressCollapse.toggle()
self.togglePrevNextBtn(self._currScanIdx)
""" control point related """
def createPointListNode(self, name, onMarkupNodeModified, color):
displayNode = slicer.mrmlScene.AddNewNodeByClass(
"vtkMRMLMarkupsDisplayNode")
displayNode.SetTextScale(0)
displayNode.SetSelectedColor(color)
pointListNode = slicer.mrmlScene.AddNewNodeByClass(
"vtkMRMLMarkupsFiducialNode")
pointListNode.SetName(name)
pointListNode.SetAndObserveDisplayNodeID(displayNode.GetID())
pointListNodeObservers = []
self.addPointListNodeObserver(pointListNode, onMarkupNodeModified)
return pointListNode, pointListNodeObservers
    def removePointListNodeObservers(self, pointListNode,
                                     pointListNodeObservers):
        # Detach every observer tag in ``pointListNodeObservers`` from the node.
        # NOTE(review): an identical definition of this method appears again
        # further down in this class; the later one wins at class-creation
        # time. Consider removing one copy.
        if pointListNode and pointListNodeObservers:
            for observer in pointListNodeObservers:
                pointListNode.RemoveObserver(observer)
def addPointListNodeObserver(self, pointListNode, onMarkupNodeModified):
pointListNodeObservers = []
if pointListNode:
eventIds = [slicer.vtkMRMLMarkupsNode.PointPositionDefinedEvent]
for eventId in eventIds:
pointListNodeObservers.append(
pointListNode.AddObserver(eventId, onMarkupNodeModified))
return pointListNodeObservers
def getControlPointsXYZ(self, pointListNode, name):
v = self._currVolumeNode
RasToIjkMatrix = vtk.vtkMatrix4x4()
v.GetRASToIJKMatrix(RasToIjkMatrix)
point_set = []
n = pointListNode.GetNumberOfControlPoints()
for i in range(n):
coord = pointListNode.GetNthControlPointPosition(i)
world = [0, 0, 0]
pointListNode.GetNthControlPointPositionWorld(i, world)
p_Ras = [coord[0], coord[1], coord[2], 1.0]
p_Ijk = RasToIjkMatrix.MultiplyDoublePoint(p_Ras)
p_Ijk = [round(i) for i in p_Ijk]
point_set.append(p_Ijk[0:3])
logging.info(f"{name} => Current control points: {point_set}")
return point_set
def getControlPointXYZ(self, pointListNode, index):
v = self._currVolumeNode
RasToIjkMatrix = vtk.vtkMatrix4x4()
v.GetRASToIJKMatrix(RasToIjkMatrix)
coord = pointListNode.GetNthControlPointPosition(index)
world = [0, 0, 0]
pointListNode.GetNthControlPointPositionWorld(index, world)
p_Ras = [coord[0], coord[1], coord[2], 1.0]
p_Ijk = RasToIjkMatrix.MultiplyDoublePoint(p_Ras)
p_Ijk = [round(i) for i in p_Ijk]
return p_Ijk[0:3]
def resetPointList(self, markupsPlaceWidget, pointListNode,
pointListNodeObservers):
if markupsPlaceWidget.placeModeEnabled:
markupsPlaceWidget.setPlaceModeEnabled(False)
if pointListNode:
slicer.mrmlScene.RemoveNode(pointListNode)
self.removePointListNodeObservers(pointListNode,
pointListNodeObservers)
    def removePointListNodeObservers(self, pointListNode,
                                     pointListNodeObservers):
        # Detach every observer tag in ``pointListNodeObservers`` from the node.
        # NOTE(review): this is a byte-for-byte duplicate of the definition
        # earlier in the class; this later one is the binding Python keeps.
        # Consider removing one copy.
        if pointListNode and pointListNodeObservers:
            for observer in pointListNodeObservers:
                pointListNode.RemoveObserver(observer)
    def enterInteractiveMode(self):
        """Switch to interactive (click-based) segmentation for the active segment."""
        if self._usingInteractive:
            return
        segmentation = self.segmentation
        segmentId = self.ui.embeddedSegmentEditorWidget.currentSegmentID()
        segment = segmentation.GetSegment(segmentId)
        if len(segmentation.GetSegmentIDs()) == 0 or len(
                segmentId) == 0:  # no segment or currently no active segment
            segmentId = segmentation.AddEmptySegment("")
        else:
            # active segment already has voxels: start a fresh segment with the
            # same name/color so interactive edits don't clobber existing work
            if (slicer.util.arrayFromSegmentBinaryLabelmap(
                    self.segmentationNode, segmentId,
                    self._currVolumeNode).sum() != 0):
                # TODO: prompt and let user choose whether to create new segment
                segmentId = segmentation.AddEmptySegment("",
                                                         segment.GetName(),
                                                         segment.GetColor())
        self.ui.embeddedSegmentEditorWidget.setCurrentSegmentID(segmentId)
        if not TEST:
            self.setImage()
        self.clicker = Clicker()  # fresh click history for this segment
        # TODO: scroll to the new segment
        self.ui.embeddedSegmentEditorWidget.setDisabled(True)
        self.ui.finishSegmentButton.setEnabled(True)
        self._usingInteractive = True
def exitInteractiveMode(self):
self.ui.dgPositiveControlPointPlacementWidget.deleteAllPoints()
self.ui.dgNegativeControlPointPlacementWidget.deleteAllPoints()
self.ui.dgPositiveControlPointPlacementWidget.placeButton().setChecked(
False)
self.ui.dgNegativeControlPointPlacementWidget.placeButton().setChecked(
False)
self.ui.embeddedSegmentEditorWidget.setDisabled(False)
self.ui.finishSegmentButton.setEnabled(False)
self._usingInteractive = False
""" inference related """
def loadModelClicked(self):
model_path, param_path = self.ui.modelPathInput.currentPath, self.ui.paramPathInput.currentPath
if not model_path or not param_path:
slicer.util.errorDisplay(
"Please set the model_path and parameter path before load model."
)
return
self.inference_predictor = predictor.BasePredictor(
model_path,
param_path,
device=self.device,
enable_mkldnn=self.enable_mkldnn,
**self.predictor_params_)
slicer.util.delayDisplay(
"Sucessfully loaded model to {}!".format(self.device),
autoCloseMsec=1500)
def onControlPointAdded(self, observer, eventid):
    """React to a newly placed positive/negative click.

    Enters interactive mode on the first click, runs inference (or a
    fixed test cube in TEST mode) and writes the resulting mask into the
    current segment.  Guarded against re-entrancy via
    ``self._addingControlPoint``.
    """
    if self._addingControlPoint:
        return
    self._addingControlPoint = True
    self.initPb("Entering interactive mode", "Doing Inference")
    if not self._usingInteractive:
        self.enterInteractiveMode()
    # 1. get new point pos and type
    self.setPb(0.1, "Preparing interactive segment")
    posPoints = self.getControlPointsXYZ(self.dgPositivePointListNode,
                                         "positive")
    negPoints = self.getControlPointsXYZ(self.dgNegativePointListNode,
                                         "negative")
    newPointIndex = observer.GetDisplayNode().GetActiveControlPoint()
    # Fix: logging.info needs a %s placeholder for the extra argument;
    # the original call made the logging module report a formatting error.
    logging.info("newPointIndex %s", newPointIndex)
    newPointPos = self.getControlPointXYZ(observer, newPointIndex)
    # The new point is positive iff it is the last point in the positive list.
    isPositivePoint = False if len(
        posPoints) == 0 else newPointPos == posPoints[-1]
    logging.info(
        f"{['Negative', 'Positive'][int(isPositivePoint)]} point added at {newPointPos}"
    )
    # 2. ensure current segment empty, create if not
    segmentation = self.segmentation
    segmentId = self.ui.embeddedSegmentEditorWidget.currentSegmentID()
    segment = segmentation.GetSegment(segmentId)
    logging.info(
        f"Current segment: {self.getSegmentId(segment)} {segment.GetName()} {segment.GetLabelValue()}",
    )
    with slicer.util.tryWithErrorDisplay(
            "Failed to run inference.", waitCursor=True):
        self.setPb(0.2, "Running inference")
        # predict image for test
        if TEST:
            # TEST mode: paint a fixed 20-voxel cube around the click
            # instead of running the network.
            p = newPointPos
            p = [p[2], p[1], p[0]]
            res = slicer.util.arrayFromSegmentBinaryLabelmap(
                self.segmentationNode, segmentId, self._currVolumeNode)
            mask = np.zeros_like(res)
            mask[p[0] - 10:p[0] + 10, p[1] - 10:p[1] + 10, p[2] - 10:p[2] +
                 10] = 1
        else:
            paddle.device.cuda.empty_cache()
            mask = self.infer_image(
                newPointPos, isPositivePoint)  # (880, 880, 12) same as res
        self.setPb(0.9, "Wrapping up")
        # set new numpy mask to segmentation
        slicer.util.updateSegmentBinaryLabelmapFromArray(
            mask, self.segmentationNode, segmentId, self._currVolumeNode)
        if self.test_iou:
            label = sitk.ReadImage(self._currLabelPath)
            label = sitk.GetArrayFromImage(label).astype("int32")
            iou = self.get_iou(label, mask, newPointPos)
            logging.info("Current IOU is {}".format(iou))
    self.closePb()
    self._addingControlPoint = False
def get_iou(self, gt_mask, pred_mask, newPointPos, ignore_label=-1):
    """Compute IoU between the prediction and the clicked GT object.

    The ground-truth object is the set of voxels sharing the label value
    found at ``newPointPos`` (given as [x, y, z], indexed [z, y, x]
    here); voxels labelled ``ignore_label`` are excluded from both terms.
    """
    valid = gt_mask != ignore_label
    predicted = pred_mask == 1
    clicked_label = gt_mask[newPointPos[2], newPointPos[1], newPointPos[0]]
    target = gt_mask == clicked_label
    overlap = (predicted & target & valid).sum()
    combined = ((predicted | target) & valid).sum()
    return overlap / combined
def prepImage(self):
    """Load the current scan, window/resample it and cache the network input.

    Produces ``self.input_data`` (leading batch axis, values scaled to
    [0, 1]) and records ``self.origin`` / ``self.new_spacing`` so
    predictions can be mapped back to the original geometry.
    """
    self.origin = sitk.ReadImage(self._scanPaths[self._currScanIdx])
    itk_img_res = inference.crop_wwwc(
        self.origin, max_v=self.image_ww[1], min_v=self.image_ww[0]
    )  # same as the preprocess when you train the model (512, 512, 12) WHD
    itk_img_res, self.new_spacing = inference.resampleImage(
        itk_img_res, out_size=self.train_shape)  # origin: (880, 880, 12)
    npy_img = sitk.GetArrayFromImage(itk_img_res).astype(
        "float32")  # 12, 512, 512 DHW
    # Exchange dim (DHW -> WHD) and normalize; add a batch axis.
    input_data = np.expand_dims(np.transpose(npy_img, [2, 1, 0]), axis=0)
    # Scale to [0, 1]; the guard avoids dividing by zero for an all-zero volume.
    if input_data.max() > 0:
        input_data = input_data / input_data.max()
    self.input_data = input_data
def setImage(self):
    """Hand ``self.input_data`` to the loaded predictor.

    Shows a transient error dialog when no model has been loaded yet
    (``self.inference_predictor`` does not exist).
    """
    shape_msg = f"输入网络前数据的形状:{self.input_data.shape}"
    logging.info(shape_msg)  # expected to be (1, W, H, D)
    try:
        self.inference_predictor.set_input_image(self.input_data)
    except AttributeError:
        slicer.util.errorDisplay(
            "Please load model first", autoCloseMsec=1200)
def infer_image(self,
                click_position=None,
                positive_click=True,
                pred_thr=0.49):
    """
    Run the predictor for the accumulated clicks and return a binary
    mask resampled back onto the original scan grid (numpy float32, DHW).

    click_position: one or several clicks represented by a list like [[234, 284, 7]]
    positive_click: whether this click is positive or negative
    pred_thr: probability threshold used to binarize the prediction
    """
    try:
        paddle.device.set_device(self.device)
    except AttributeError:
        slicer.util.errorDisplay(
            "Model is not loaded. Please load model first")
        return
    except ValueError:
        slicer.util.errorDisplay(
            "The AI-assisted image infer process need to be run on gpu device, please install paddle with GPU enabled."
        )
        # NOTE(review): no return here — execution continues after the
        # dialog; confirm whether this branch should also abort.
    tic = time.time()
    self.prepare_click(click_position, positive_click)
    with paddle.no_grad():
        pred_probs = self.inference_predictor.get_prediction_noclicker(
            self.clicker)
    # Keep probabilities above threshold, then binarize them to {0, 1}.
    output_data = (pred_probs > pred_thr) * pred_probs  # (12, 512, 512) DHW
    output_data[output_data > 0] = 1
    # Load mask from model infer result, and change from numpy to simpleitk
    output_data = np.transpose(output_data, [2, 1, 0])
    mask_itk_new = sitk.GetImageFromArray(output_data)  # (512, 512, 12) WHD
    # Restore the original geometry metadata on the network-space mask.
    mask_itk_new.SetSpacing(self.new_spacing)
    mask_itk_new.SetOrigin(self.origin.GetOrigin())
    mask_itk_new.SetDirection(self.origin.GetDirection())
    mask_itk_new = sitk.Cast(mask_itk_new, sitk.sitkUInt8)
    # if need max connect opponet filter, add it before here.
    # Resample the mask back onto the original image grid (nearest neighbour
    # to keep labels binary).
    Mask, _ = inference.resampleImage(mask_itk_new,
                                      self.origin.GetSize(),
                                      self.origin.GetSpacing(),
                                      sitk.sitkNearestNeighbor)
    Mask.CopyInformation(self.origin)
    npy_img = sitk.GetArrayFromImage(Mask).astype(
        "float32")  # 12, 512, 512 DHW
    # Log message (Chinese): "prediction shape ..., prediction time ... ms".
    logging.info(
        f"预测结果的形状:{output_data.shape}, 预测时间为 {(time.time() - tic) * 1000} ms"
    )  # shape (12, 512, 512) DHW test
    return npy_img
def prepare_click(self, click_position, positive_click):
    """Scale a click into network-input space and register it.

    Parameters
    ----------
    click_position : sequence of 3 ints, [x, y, z] in the original scan.
    positive_click : bool, True for a foreground (positive) click.
    """
    # Rescale each coordinate into the resampled (network input) space.
    click_position_new = [
        int(self.ratio[i] * pos) for i, pos in enumerate(click_position)
    ]
    # Fourth component: +100 for positive clicks, -100 for negative ones
    # (presumably a polarity marker consumed downstream — TODO confirm).
    click_position_new.append(100 if positive_click else -100)
    logging.info("The %s click is click on %s (resampled)",
                 ["negative", "positive"][positive_click],
                 click_position_new)  # result is correct
    click = inference.Click(
        is_positive=positive_click, coords=click_position_new)
    self.clicker.add_click(click)
    # Fix: the extra argument needs a %s placeholder; the original call
    # made the logging module report a formatting error.
    logging.info("####################### clicker length %s",
                 len(self.clicker.clicks_list))

""" saving related """
def finishScan(self):
    """Mark the current scan as finished, persist state and move on.

    Advances forward or backward depending on the direction of the last
    scan switch (``self._lastTurnNextScan``).
    """
    if self._usingInteractive:
        self.exitInteractiveMode()
    self.saveSegmentation()
    currentPath = self._scanPaths[self._currScanIdx]
    self._finishedPaths.append(currentPath)
    self.saveProgress()
    self.updateProgressWidgets()
    advance = self.nextScan if self._lastTurnNextScan else self.prevScan
    advance()
def saveSegmentation(self):
    """
    Save the current segmentation as a labelmap next to the scan file.

    Skips writing when nothing changed since the last save
    (``self._dirty`` is False).  The label file is named
    ``<scan>_label.<ext>``.
    """
    tic = time.time()
    if not self._dirty:
        logging.info("Segmentation not changed, skip saving")
        slicer.app.processEvents()
        return
    catgs = self.getCatgFromFile()
    segmentationNode = self.segmentationNode
    # Fix: the extra argument needs a %s placeholder; the original call
    # made the logging module report a formatting error.
    logging.info("segmentationNode %s", segmentationNode)
    segmentation = segmentationNode.GetSegmentation()
    # 2. prepare save path
    scanPath = self._scanPaths[self._currScanIdx]
    # NOTE(review): splits at the FIRST dot so "scan.nii.gz" becomes
    # "scan_label.nii.gz"; a dot in a parent directory name would break
    # this — confirm expected inputs.
    dotPos = scanPath.find(".")
    labelPath = scanPath[:dotPos] + "_label" + scanPath[dotPos:]
    # 3. save
    colorTableNode = slicer.mrmlScene.AddNewNodeByClass(
        "vtkMRMLColorTableNode")
    colorTableNode.SetTypeToUser()
    colorTableNode.SetNumberOfColors(len(self.segmentation.GetSegmentIDs()))
    colorTableNode.UnRegister(None)
    colorTableNode.SetNamesInitialised(True)
    for segment in self.segments:
        colorTableNode.SetColor(catgs[segment.GetName()],
                                segment.GetName(), *segment.GetColor(), 1.0)
    labelmapVolumeNode = slicer.mrmlScene.AddNewNodeByClass(
        "vtkMRMLLabelMapVolumeNode")
    slicer.modules.segmentations.logic().ExportSegmentsToLabelmapNode(
        segmentationNode,
        segmentation.GetSegmentIDs(),
        labelmapVolumeNode,
        self._currVolumeNode,
        segmentation.EXTENT_UNION_OF_EFFECTIVE_SEGMENTS,
        colorTableNode, )
    res = slicer.util.saveNode(labelmapVolumeNode, labelPath)
    # clean up useless nodes
    slicer.mrmlScene.RemoveNode(labelmapVolumeNode)
    if res:
        logging.info(f"{labelPath.split('/')[-1]} save successfully.")
    else:
        slicer.util.errorDisplay(f"{labelPath.split('/')[-1]} save failed!")
    self._dirty = False
    logging.info(f"saving took {time.time() - tic}s")

""" display related """
def opacityUi2Display(self):
    """Apply the opacity slider value to the segmentation display (2D and 3D)."""
    segmentationNode = self.segmentationNode
    if segmentationNode is None:
        return
    opacity = self.ui.opacitySlider.value
    displayNode = segmentationNode.GetDisplayNode()
    displayNode.SetOpacity3D(opacity)  # 3D rendering
    displayNode.SetOpacity(opacity)  # 2D slice views
def opacityDisplay2Ui(self):
    """Sync the opacity slider with the current segmentation display node."""
    segmentationNode = self.segmentationNode
    if segmentationNode is None:
        return
    displayNode = segmentationNode.GetDisplayNode()
    if displayNode is None:
        return
    self.ui.opacitySlider.value = displayNode.GetOpacity()

""" life cycle related """
def cleanup(self):
    """
    Called when the application closes and the module widget is destroyed.
    """
    # Suppress scan-turning logic while tearing down.
    self._turninig = True
    self.clearScene(clearAllVolumes=True)
    self.removeObservers()
    # Release both point lists together with their observers.
    self.resetPointList(
        self.ui.dgPositiveControlPointPlacementWidget,
        self.dgPositivePointListNode,
        self.dgPositivePointListNodeObservers, )
    self.dgPositivePointListNode = None
    self.resetPointList(
        self.ui.dgNegativeControlPointPlacementWidget,
        self.dgNegativePointListNode,
        self.dgNegativePointListNodeObservers, )
    self.dgNegativePointListNode = None
def enter(self):
    """Called each time the user opens this module (not on reload or
    when switching back from another module)."""
def exit(self):
    """Called each time the user switches to a different module."""
    # Stop reacting to parameter node changes; the GUI will be refreshed
    # when the user re-enters the module.
    self.removeObserver(
        self._parameterNode,
        vtk.vtkCommand.ModifiedEvent,
        self.updateGUIFromParameterNode, )
def onReload(self):
    # Tear down scene objects and observers before the base class
    # re-imports the module source.
    self.cleanup()
    super().onReload()
def onSceneEndImport(self, caller, event):
    """
    Called after reload and after scan/segmentation is imported
    """
    # Re-entrancy guard.  No work is currently done between set and
    # reset, so this handler is effectively a placeholder.
    if self._endImportProcessing:
        return
    self._endImportProcessing = True
    self._endImportProcessing = False
def onSceneStartClose(self, caller, event):
    """
    Called just before the scene is closed.
    """
    # Suppress scan-turning reactions while the scene tears down.
    self._turninig = True
    # Parameter node will be reset, do not use it anymore
    self.saveProgress()
    self.setParameterNode(None)
    # Drop both point lists and their observers (same teardown as cleanup()).
    self.resetPointList(
        self.ui.dgPositiveControlPointPlacementWidget,
        self.dgPositivePointListNode,
        self.dgPositivePointListNodeObservers, )
    self.dgPositivePointListNode = None
    self.resetPointList(
        self.ui.dgNegativeControlPointPlacementWidget,
        self.dgNegativePointListNode,
        self.dgNegativePointListNodeObservers, )
    self.dgNegativePointListNode = None
def onSceneEndClose(self, caller, event):
    """Called just after the scene is closed."""
    # When the module is currently visible, rebuild its parameter node
    # right away so the GUI stays functional.
    if not self.parent.isEntered:
        return
    self.initializeParameterNode()
def initializeFromNode(self):
    """Ensure a parameter node exists and is observed; sync opacity UI.

    The parameter node persists user choices so they are restored when
    the scene is saved and reloaded.
    """
    self.setParameterNode(self.logic.getParameterNode())
    segmentationNode = self.segmentationNode
    if segmentationNode is None:
        return
    displayOpacity = segmentationNode.GetDisplayNode().GetOpacity()
    self.ui.opacitySlider.setValue(displayOpacity)
def setParameterNode(self, inputParameterNode):
    """
    Set and observe parameter node.
    Observation is needed because when the parameter node is changed then the GUI must be updated immediately.
    """
    # Unobserve previously selected parameter node and add an observer to the newly selected.
    # Changes of parameter node are observed so that whenever parameters are changed by a script or any other module
    # those are reflected immediately in the GUI.
    if self._parameterNode is not None:
        self.removeObserver(
            self._parameterNode,
            vtk.vtkCommand.ModifiedEvent,
            self.updateGUIFromParameterNode, )
    self._parameterNode = inputParameterNode
    if self._parameterNode is not None:
        self.addObserver(
            self._parameterNode,
            vtk.vtkCommand.ModifiedEvent,
            self.updateGUIFromParameterNode, )
    # Initial GUI update
    self.updateGUIFromParameterNode()
def updateGUIFromParameterNode(self, caller=None, event=None):
    """
    This method is called whenever parameter node is changed.
    The module GUI is updated to show the current state of the parameter node.
    """
    if self._parameterNode is None or self._updatingGUIFromParameterNode:
        return
    # Make sure GUI changes do not call updateParameterNodeFromGUI (it could cause infinite loop)
    self._updatingGUIFromParameterNode = True
    # Lazily create the positive point list on first use.
    if not self.dgPositivePointListNode:
        (
            self.dgPositivePointListNode,
            self.dgPositivePointListNodeObservers,
        ) = self.createPointListNode("P", self.onControlPointAdded,
                                     [0.5, 1, 0.5])
        self.ui.dgPositiveControlPointPlacementWidget.setCurrentNode(
            self.dgPositivePointListNode)
        self.ui.dgPositiveControlPointPlacementWidget.setPlaceModeEnabled(
            False)
    # NOTE(review): the negative list reuses the same label "P" and the
    # same green color as the positive list — looks like a copy-paste
    # slip; confirm whether it should use a distinct label/color.
    if not self.dgNegativePointListNode:
        (
            self.dgNegativePointListNode,
            self.dgNegativePointListNodeObservers,
        ) = self.createPointListNode("P", self.onControlPointAdded,
                                     [0.5, 1, 0.5])
        self.ui.dgNegativeControlPointPlacementWidget.setCurrentNode(
            self.dgNegativePointListNode)
        self.ui.dgNegativeControlPointPlacementWidget.setPlaceModeEnabled(
            False)
    # All the GUI updates are done
    self._updatingGUIFromParameterNode = False
def updateParameterNodeFromGUI(self, caller=None, event=None):
    """
    This method is called when the user makes any change in the GUI.
    The changes are saved into the parameter node (so that they are restored when the scene is saved and loaded).
    """
    if self._parameterNode is None or self._updatingGUIFromParameterNode:
        return
    wasModified = self._parameterNode.StartModify(
    )  # Modify all properties in a single batch
    # NOTE(review): nothing is written between StartModify and EndModify —
    # this handler is currently a placeholder.
    self._parameterNode.EndModify(wasModified)
#
# EISegMed3DLogic
#
class EISegMed3DLogic(ScriptedLoadableModuleLogic):
    """This class should implement all the actual
    computation done by your module.  The interface
    should be such that other python code can import
    this class and make use of the functionality without
    requiring an instance of the Widget.
    Uses ScriptedLoadableModuleLogic base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def __init__(self):
        """
        Called when the logic class is instantiated. Can be used for initializing member variables.
        """
        ScriptedLoadableModuleLogic.__init__(self)

    def setDefaultParameters(self, parameterNode):
        """
        Initialize parameter node with default settings.
        """
        if not parameterNode.GetParameter("Threshold"):
            parameterNode.SetParameter("Threshold", "100.0")

    def process(self,
                inputVolume,
                outputVolume,
                imageThreshold,
                invert=False,
                showResult=True):
        """
        Run the processing algorithm.
        Can be used without GUI widget.
        :param inputVolume: volume to be thresholded
        :param outputVolume: thresholding result
        :param imageThreshold: values above/below this threshold will be set to 0
        :param invert: if True then values above the threshold will be set to 0, otherwise values below are set to 0
        :param showResult: show output volume in slice viewers
        """
        if not inputVolume or not outputVolume:
            raise ValueError("Input or output volume is invalid")
        # Fix: dropped the redundant function-level `import time`;
        # `time` is already imported at module level.
        startTime = time.time()
        logging.info("Processing started")
        # Compute the thresholded output volume using the "Threshold Scalar Volume" CLI module
        cliParams = {
            "InputVolume": inputVolume.GetID(),
            "OutputVolume": outputVolume.GetID(),
            "ThresholdValue": imageThreshold,
            "ThresholdType": "Above" if invert else "Below",
        }
        cliNode = slicer.cli.run(
            slicer.modules.thresholdscalarvolume,
            None,
            cliParams,
            wait_for_completion=True,
            update_display=showResult, )
        # We don't need the CLI module node anymore, remove it to not clutter the scene with it
        slicer.mrmlScene.RemoveNode(cliNode)
        stopTime = time.time()
        logging.info(
            f"Processing completed in {stopTime - startTime:.2f} seconds")

    def get_segment_editor_node(self):
        # Use the Segment Editor module's parameter node for the embedded segment editor widget.
        # This ensures that if the user switches to the Segment Editor then the selected
        # segmentation node, volume node, etc. are the same.
        segmentEditorSingletonTag = "SegmentEditor"
        segmentEditorNode = slicer.mrmlScene.GetSingletonNode(
            segmentEditorSingletonTag, "vtkMRMLSegmentEditorNode")
        if segmentEditorNode is None:
            segmentEditorNode = slicer.mrmlScene.CreateNodeByClass(
                "vtkMRMLSegmentEditorNode")
            segmentEditorNode.UnRegister(None)
            segmentEditorNode.SetSingletonTag(segmentEditorSingletonTag)
            segmentEditorNode = slicer.mrmlScene.AddNode(segmentEditorNode)
        return segmentEditorNode
#
# EISegMed3DTest
#
class EISegMed3DTest(ScriptedLoadableModuleTest):
    """
    This is the test case for your scripted module.
    Uses ScriptedLoadableModuleTest base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def setUp(self):
        """Do whatever is needed to reset the state - typically a scene clear will be enough."""
        slicer.mrmlScene.Clear()

    def runTest(self):
        """Run as few or as many tests as needed here."""
        self.setUp()
        self.test_EISegMed3D1()

    def test_EISegMed3D1(self):
        """Ideally you should have several levels of tests. At the lowest level
        tests should exercise the functionality of the logic with different inputs
        (both valid and invalid). At higher levels your tests should emulate the
        way the user would interact with your code and confirm that it still works
        the way you intended.
        One of the most important features of the tests is that it should alert other
        developers when their changes will have an impact on the behavior of your
        module. For example, if a developer removes a feature that you depend on,
        your test should break so they know that the feature is needed.
        """
        self.delayDisplay("Starting the test")
        # Get/create input data
        import SampleData
        # NOTE(review): registerSampleData is not defined anywhere visible
        # in this file — confirm it is provided by the module template.
        registerSampleData()
        inputVolume = SampleData.downloadSample("EISegMed3D1")
        self.delayDisplay("Loaded test data set")
        inputScalarRange = inputVolume.GetImageData().GetScalarRange()
        self.assertEqual(inputScalarRange[0], 0)
        self.assertEqual(inputScalarRange[1], 695)
        outputVolume = slicer.mrmlScene.AddNewNodeByClass(
            "vtkMRMLScalarVolumeNode")
        threshold = 100
        # Test the module logic
        logic = EISegMed3DLogic()
        # Test algorithm with non-inverted threshold
        # NOTE(review): invert=True is passed here and False below — the
        # flags look swapped relative to these comments; confirm intent.
        logic.process(inputVolume, outputVolume, threshold, True)
        outputScalarRange = outputVolume.GetImageData().GetScalarRange()
        self.assertEqual(outputScalarRange[0], inputScalarRange[0])
        self.assertEqual(outputScalarRange[1], threshold)
        # Test algorithm with inverted threshold
        logic.process(inputVolume, outputVolume, threshold, False)
        outputScalarRange = outputVolume.GetImageData().GetScalarRange()
        self.assertEqual(outputScalarRange[0], inputScalarRange[0])
        self.assertEqual(outputScalarRange[1], inputScalarRange[1])
        self.delayDisplay("Test passed")
<?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
<class>placePoint</class>
<widget class="qMRMLWidget" name="placePoint">
<property name="enabled">
<bool>true</bool>
</property>
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>682</width>
<height>847</height>
</rect>
</property>
<property name="sizePolicy">
<sizepolicy hsizetype="Preferred" vsizetype="Preferred">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<layout class="QVBoxLayout" name="verticalLayout">
<item>
<widget class="ctkCollapsibleButton" name="modelSettingsCollapse">
<property name="text">
<string>Model Settings</string>
</property>
<property name="collapsed">
<bool>false</bool>
</property>
<property name="collapsedHeight">
<number>9</number>
</property>
<layout class="QVBoxLayout" name="verticalLayout_2">
<item>
<layout class="QFormLayout" name="modelLayout">
<item row="0" column="0">
<widget class="QLabel" name="modelPathLabel">
<property name="text">
<string>Model Path: </string>
</property>
</widget>
</item>
<item row="0" column="1">
<widget class="ctkPathLineEdit" name="modelPathInput" />
</item>
<item row="1" column="0">
<widget class="QLabel" name="paramPathLabel">
<property name="text">
<string>Param Path: </string>
</property>
</widget>
</item>
<item row="1" column="1">
<widget class="ctkPathLineEdit" name="paramPathInput" />
</item>
</layout>
</item>
<item>
<layout class="QHBoxLayout" name="loadModelLayout">
<item>
<widget class="QPushButton" name="loadModelButton">
<property name="text">
<string>Load Static Model</string>
</property>
</widget>
</item>
</layout>
</item>
</layout>
</widget>
</item>
<item>
<widget class="Line" name="line_2">
<property name="orientation">
<enum>Qt::Horizontal</enum>
</property>
</widget>
</item>
<item>
<layout class="QHBoxLayout" name="dataFolderLayout">
<item>
<widget class="QLabel" name="dataFolderLabel">
<property name="text">
<string>Data Folder :</string>
</property>
</widget>
</item>
<item>
<widget class="ctkDirectoryButton" name="dataFolderButton">
<property name="options">
<set>ctkDirectoryButton::HideNameFilterDetails|ctkDirectoryButton::ReadOnly</set>
</property>
</widget>
</item>
</layout>
</item>
<item>
<widget class="QCheckBox" name="skipFinished">
<property name="text">
<string>Skip Finished Scans</string>
</property>
<property name="checked">
<bool>false</bool>
</property>
</widget>
</item>
<item>
<layout class="QHBoxLayout" name="prevNextLayout">
<item>
<widget class="QPushButton" name="prevScanButton">
<property name="enabled">
<bool>false</bool>
</property>
<property name="text">
<string>Prev Scan</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="nextScanButton">
<property name="enabled">
<bool>false</bool>
</property>
<property name="text">
<string>Next Scan</string>
</property>
</widget>
</item>
</layout>
</item>
<item>
<widget class="Line" name="line_3">
<property name="orientation">
<enum>Qt::Horizontal</enum>
</property>
</widget>
</item>
<item>
<layout class="QHBoxLayout" name="placePointLayout">
<item>
<widget class="QLabel" name="positiveLabel">
<property name="text">
<string>Positive Point:</string>
</property>
</widget>
</item>
<item>
<widget class="qSlicerMarkupsPlaceWidget" name="dgPositiveControlPointPlacementWidget">
<property name="enabled">
<bool>false</bool>
</property>
<property name="buttonsVisible">
<bool>false</bool>
</property>
<property name="placeMultipleMarkups">
<enum>qSlicerMarkupsPlaceWidget::ForcePlaceMultipleMarkups</enum>
</property>
<property name="nodeColor">
<color>
<red>0</red>
<green>1</green>
<blue>0</blue>
</color>
</property>
<property name="currentNodeActive">
<bool>false</bool>
</property>
<property name="placeModeEnabled">
<bool>false</bool>
</property>
<property name="placeModePersistency">
<bool>false</bool>
</property>
<property name="deleteAllMarkupsOptionVisible">
<bool>false</bool>
</property>
</widget>
</item>
<item>
<widget class="QLabel" name="negativeLabel">
<property name="text">
<string>Negative Point:</string>
</property>
</widget>
</item>
<item>
<widget class="qSlicerMarkupsPlaceWidget" name="dgNegativeControlPointPlacementWidget">
<property name="enabled">
<bool>false</bool>
</property>
<property name="buttonsVisible">
<bool>false</bool>
</property>
<property name="deleteAllControlPointsOptionVisible">
<bool>false</bool>
</property>
<property name="placeMultipleMarkups">
<enum>qSlicerMarkupsPlaceWidget::ForcePlaceMultipleMarkups</enum>
</property>
</widget>
</item>
</layout>
</item>
<item>
<layout class="QHBoxLayout" name="finishLayout">
<item>
<widget class="QPushButton" name="finishSegmentButton">
<property name="enabled">
<bool>false</bool>
</property>
<property name="text">
<string>Finish Segment</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="finishScanButton">
<property name="enabled">
<bool>false</bool>
</property>
<property name="text">
<string>Finish Scan</string>
</property>
</widget>
</item>
</layout>
</item>
<item>
<widget class="Line" name="line_5">
<property name="orientation">
<enum>Qt::Horizontal</enum>
</property>
</widget>
</item>
<item>
<widget class="Line" name="line">
<property name="orientation">
<enum>Qt::Horizontal</enum>
</property>
</widget>
</item>
<item>
<widget class="ctkCollapsibleButton" name="segmentEditorCollapse">
<property name="text">
<string>Segment Editor</string>
</property>
<property name="collapsed">
<bool>false</bool>
</property>
<property name="collapsedHeight">
<number>9</number>
</property>
<layout class="QVBoxLayout" name="verticalLayout_3">
<item>
<widget class="qMRMLSegmentEditorWidget" name="embeddedSegmentEditorWidget">
<property name="sizePolicy">
<sizepolicy hsizetype="Preferred" vsizetype="Expanding">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<property name="segmentationNodeSelectorVisible">
<bool>false</bool>
</property>
<property name="sourceVolumeNodeSelectorVisible">
<bool>false</bool>
</property>
<property name="masterVolumeNodeSelectorVisible">
<bool>false</bool>
</property>
<property name="switchToSegmentationsButtonVisible">
<bool>true</bool>
</property>
<property name="effectColumnCount">
<number>3</number>
</property>
<property name="unorderedEffectsVisible">
<bool>false</bool>
</property>
<property name="jumpToSelectedSegmentEnabled">
<bool>true</bool>
</property>
</widget>
</item>
</layout>
</widget>
</item>
<item>
<layout class="QHBoxLayout" name="opacityLayout">
<item>
<widget class="QLabel" name="threshLabel">
<property name="text">
<string>Segment Opacity: </string>
</property>
</widget>
</item>
<item>
<widget class="ctkSliderWidget" name="opacitySlider">
<property name="sizePolicy">
<sizepolicy hsizetype="Preferred" vsizetype="Fixed">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<property name="decimals">
<number>2</number>
</property>
<property name="singleStep">
<double>0.010000000000000</double>
</property>
<property name="pageStep">
<double>0.100000000000000</double>
</property>
<property name="maximum">
<double>1.000000000000000</double>
</property>
<property name="value">
<double>0.900000000000000</double>
</property>
</widget>
</item>
</layout>
</item>
<item>
<widget class="ctkCollapsibleButton" name="progressCollapse">
<property name="text">
<string>Progress</string>
</property>
<property name="collapsed">
<bool>true</bool>
</property>
<property name="collapsedHeight">
<number>9</number>
</property>
<layout class="QVBoxLayout" name="progressLayout">
<item>
<layout class="QHBoxLayout" name="horizontalLayout">
<item>
<widget class="QLabel" name="progressLabel">
<property name="text">
<string>Annotation Progress: </string>
</property>
</widget>
</item>
<item>
<widget class="QProgressBar" name="annProgressBar">
<property name="value">
<number>0</number>
</property>
</widget>
</item>
<item>
<widget class="QLabel" name="progressDetail">
<property name="text">
<string />
</property>
</widget>
</item>
</layout>
</item>
<item>
<widget class="QTableWidget" name="progressTable">
<property name="sizePolicy">
<sizepolicy hsizetype="Minimum" vsizetype="Minimum">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<property name="focusPolicy">
<enum>Qt::NoFocus</enum>
</property>
<property name="sizeAdjustPolicy">
<enum>QAbstractScrollArea::AdjustIgnored</enum>
</property>
<property name="editTriggers">
<set>QAbstractItemView::NoEditTriggers</set>
</property>
<property name="alternatingRowColors">
<bool>true</bool>
</property>
<property name="sortingEnabled">
<bool>false</bool>
</property>
<property name="columnCount">
<number>2</number>
</property>
<attribute name="horizontalHeaderVisible">
<bool>false</bool>
</attribute>
<attribute name="horizontalHeaderCascadingSectionResizes">
<bool>true</bool>
</attribute>
<attribute name="horizontalHeaderShowSortIndicator" stdset="0">
<bool>false</bool>
</attribute>
<attribute name="horizontalHeaderStretchLastSection">
<bool>true</bool>
</attribute>
<attribute name="verticalHeaderVisible">
<bool>false</bool>
</attribute>
<column>
<property name="text">
<string>Finished</string>
</property>
</column>
<column>
<property name="text">
<string>Scan Name</string>
</property>
</column>
</widget>
</item>
</layout>
</widget>
</item>
<item>
<spacer name="verticalSpacer">
<property name="orientation">
<enum>Qt::Vertical</enum>
</property>
<property name="sizeType">
<enum>QSizePolicy::Expanding</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>20</width>
<height>40</height>
</size>
</property>
</spacer>
</item>
</layout>
</widget>
<customwidgets>
<customwidget>
<class>ctkCollapsibleButton</class>
<extends>QWidget</extends>
<header>ctkCollapsibleButton.h</header>
<container>1</container>
</customwidget>
<customwidget>
<class>ctkDirectoryButton</class>
<extends>QWidget</extends>
<header>ctkDirectoryButton.h</header>
</customwidget>
<customwidget>
<class>ctkPathLineEdit</class>
<extends>QWidget</extends>
<header>ctkPathLineEdit.h</header>
</customwidget>
<customwidget>
<class>ctkSliderWidget</class>
<extends>QWidget</extends>
<header>ctkSliderWidget.h</header>
</customwidget>
<customwidget>
<class>qMRMLWidget</class>
<extends>QWidget</extends>
<header>qMRMLWidget.h</header>
<container>1</container>
</customwidget>
<customwidget>
<class>qSlicerWidget</class>
<extends>QWidget</extends>
<header>qSlicerWidget.h</header>
<container>1</container>
</customwidget>
<customwidget>
<class>qSlicerMarkupsPlaceWidget</class>
<extends>qSlicerWidget</extends>
<header>qSlicerMarkupsPlaceWidget.h</header>
</customwidget>
<customwidget>
<class>qMRMLSegmentEditorWidget</class>
<extends>qMRMLWidget</extends>
<header>qMRMLSegmentEditorWidget.h</header>
</customwidget>
</customwidgets>
<resources />
<connections />
</ui>
#slicer_add_python_unittest(SCRIPT ${MODULE_NAME}ModuleTest.py)
from .models import VNetModel
from .ops import DistMaps3D, ScaleLayer, BatchImageNormalize3D, SigmoidForPred
from .predictor import BasePredictor, Click
from .preprocessing import *
import os
import sys
sys.path.append(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "./.."))
import paddle
import paddle.nn as nn
import numpy as np
from inference.ops import DistMaps3D, ScaleLayer, BatchImageNormalize3D
class ISModel3D(nn.Layer):
def __init__(
self,
use_rgb_conv=True,
with_aux_output=False,
norm_radius=2,
use_disks=False,
cpu_dist_maps=False,
clicks_groups=None,
with_prev_mask=False, # True
use_leaky_relu=False,
binary_prev_mask=False,
conv_extend=False,
norm_layer=nn.BatchNorm3D,
norm_mean_std=(
[0.00040428873, ],
[0.00059983705, ],
), ): # image.std(): [0.00053328] image.mean() [0.00023692])
super().__init__()
self.with_aux_output = with_aux_output
self.clicks_groups = clicks_groups
self.with_prev_mask = with_prev_mask
self.binary_prev_mask = binary_prev_mask
self.normalization = BatchImageNormalize3D(norm_mean_std[0],
norm_mean_std[1])
self.coord_feature_ch = 2
if clicks_groups is not None:
self.coord_feature_ch *= len(clicks_groups)
if self.with_prev_mask:
self.coord_feature_ch += 1 # 3
if use_rgb_conv:
rgb_conv_layers = [
nn.Conv3D(
in_channels=1 + self.coord_feature_ch,
out_channels=6 + self.coord_feature_ch,
kernel_size=1),
norm_layer(6 + self.coord_feature_ch),
nn.LeakyReLU(negative_slope=0.2)
if use_leaky_relu else nn.ReLU(),
nn.Conv3D(
in_channels=6 + self.coord_feature_ch,
out_channels=1,
kernel_size=1),
]
self.rgb_conv = nn.Sequential(*rgb_conv_layers)
elif conv_extend:
self.rgb_conv = None
self.maps_transform = nn.Conv3D(
in_channels=self.coord_feature_ch,
out_channels=64,
kernel_size=3,
stride=2,
padding=1)
else:
self.rgb_conv = None
mt_layers = [
nn.Conv3D(
in_channels=self.coord_feature_ch,
out_channels=16,
kernel_size=1),
nn.LeakyReLU(negative_slope=0.2)
if use_leaky_relu else nn.ReLU(),
nn.Conv3D(
in_channels=16,
out_channels=16,
kernel_size=3,
stride=1,
padding=1),
ScaleLayer(
init_value=0.05, lr_mult=1),
]
self.maps_transform = nn.Sequential(*mt_layers)
if self.clicks_groups is not None:
self.dist_maps = nn.LayerList()
for click_radius in self.clicks_groups:
self.dist_maps.append(
DistMaps3D(
norm_radius=click_radius,
spatial_scale=1.0,
cpu_mode=cpu_dist_maps,
use_disks=use_disks))
else:
self.dist_maps = DistMaps3D(
norm_radius=norm_radius,
spatial_scale=1.0,
cpu_mode=cpu_dist_maps,
use_disks=use_disks)
def forward(self, image, coord_features):
if self.rgb_conv is not None:
x = self.rgb_conv(paddle.concat(
(image, coord_features), axis=1)) # [B, 4, H, W, D] #
outputs = self.backbone_forward(x)
else:
coord_features = self.maps_transform(
coord_features) # [B, 3, H, W, D]
outputs = self.backbone_forward(image, coord_features)
outputs["instances"] = nn.functional.interpolate(
outputs["instances"],
size=paddle.shape(image)[2:], # [4, 20, 512, 512, 12]
mode="trilinear",
align_corners=True,
data_format="NCDHW", ) # image [4 , 1 , 512, 512, 12 ]
if self.with_aux_output:
outputs["instances_aux"] = nn.functional.interpolate(
outputs["instances_aux"],
size=paddle.shape(image)[2:],
mode="biltrilinearinear",
align_corners=True,
data_format="NCDHW", )
return outputs
    def prepare_input(self, image):
        """Split the stacked network input into image and previous mask.

        Args:
            image: tensor whose channel axis stacks the raw image (channel 0)
                and, when ``self.with_prev_mask`` is set, the previous
                prediction (channels 1 and up).

        Returns:
            tuple ``(image, prev_mask)`` where ``image`` is normalized and
            ``prev_mask`` is ``None`` when ``with_prev_mask`` is False.
        """
        prev_mask = None
        if self.with_prev_mask:
            # Channels 1.. hold the previous mask; the oversized ``ends``
            # value relies on paddle.slice clamping to the actual extent.
            prev_mask = paddle.slice(
                image,
                axes=[1, ],
                starts=[1, ],
                ends=[1000, ], )
            # Channel 0 is the raw image.
            image = paddle.slice(
                image,
                axes=[1, ],
                starts=[0, ],
                ends=[1, ], )
        if self.binary_prev_mask:
            # Threshold the previous mask into a hard {0, 1} float mask.
            prev_mask = (prev_mask > 0.5).astype("float32")
        # Normalize only the image channel (training-time statistics).
        image = self.normalization(image)
        return image, prev_mask
    def backbone_forward(self, image, coord_features=None):
        """Run the task-specific backbone; must be overridden by subclasses."""
        raise NotImplementedError
def get_coord_features(self, image, prev_mask, points):
coord_features = self.dist_maps(
image,
points) # [16, 1, 512, 512, 12], [16, 48, 4]. # [B, 2, H, W, D]
if prev_mask is not None:
coord_features = paddle.concat(
(prev_mask, coord_features), axis=1) # [B, 3, H, W, D]
return coord_features
def split_points_by_order(tpoints,
groups): # todo check if point have dimension problem
points = tpoints.numpy()
num_groups = len(groups)
bs = points.shape[0]
num_points = points.shape[1] // 2
groups = [x if x > 0 else num_points for x in groups]
group_points = [
np.full(
(bs, 2 * x, 3), -1, dtype=np.float32) for x in groups
]
last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int)
for group_indx, group_size in enumerate(groups):
last_point_indx_group[:, group_indx, 1] = group_size
for bindx in range(bs):
for pindx in range(2 * num_points):
point = points[bindx, pindx, :]
group_id = int(point[2])
if group_id < 0:
continue
is_negative = int(pindx >= num_points)
if group_id >= num_groups or (
group_id == 0 and
is_negative): # disable negative first click
group_id = num_groups - 1
new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
last_point_indx_group[bindx, group_id, is_negative] += 1
group_points[group_id][bindx, new_point_indx, :] = point
group_points = [
paddle.to_tensor(
x, dtype=tpoints.dtype) for x in group_points
]
return group_points
from paddleseg.utils import utils
class LUConv(nn.Layer):
    """A conv -> batch-norm -> activation unit that keeps the channel count."""

    def __init__(self, nchan, elu):
        super(LUConv, self).__init__()
        # Creation order matches the original checkpoints / RNG consumption.
        self.relu1 = nn.ELU() if elu else nn.PReLU(nchan)
        self.conv1 = nn.Conv3D(nchan, nchan, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm3D(nchan)

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.bn1(feat)
        return self.relu1(feat)
def _make_nConv(nchan, depth, elu):
    """Stack ``depth`` LUConv units (channel count unchanged) to add
    non-linearity to the network."""
    return nn.Sequential(*(LUConv(nchan, elu) for _ in range(depth)))
class InputTransition(nn.Layer):
    """Project the input volume to 16 channels and add a channel-tiled copy
    of the raw input as a residual."""

    def __init__(self, in_channels, elu):
        super(InputTransition, self).__init__()
        self.num_features = 16
        self.in_channels = in_channels
        self.conv1 = nn.Conv3D(
            self.in_channels, self.num_features, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm3D(self.num_features)
        self.relu1 = nn.ELU() if elu else nn.PReLU(self.num_features)

    def forward(self, x):
        projected = self.bn1(self.conv1(x))
        # Tile the raw input along channels to match num_features; assumes
        # in_channels divides num_features evenly — TODO confirm.
        times = self.num_features // self.in_channels
        residual = x.tile([1, times, 1, 1, 1])
        return self.relu1(paddle.add(projected, residual))
class DownTransition(nn.Layer):
    """Encoder stage of the VNet.

    1. Double the channel count and downsample with ``down_conv`` (kernel
       and stride configurable).
    2. Optionally apply dropout.
    3. Refine with ``nConvs`` LUConv units and residually add the
       downsampled tensor back.
    """

    def __init__(self,
                 inChans,
                 nConvs,
                 elu,
                 dropout=False,
                 downsample_stride=(2, 2, 2),
                 kernel=(2, 2, 2)):
        super(DownTransition, self).__init__()
        outChans = 2 * inChans
        self.if_dropout = dropout
        self.down_conv = nn.Conv3D(
            inChans, outChans, kernel_size=kernel, stride=downsample_stride)
        self.bn1 = nn.BatchNorm3D(outChans)
        self.relu1 = nn.ELU() if elu else nn.PReLU(outChans)
        self.relu2 = nn.ELU() if elu else nn.PReLU(outChans)
        self.dropout = nn.Dropout3D()
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x):
        downsampled = self.relu1(self.bn1(self.down_conv(x)))
        features = self.dropout(downsampled) if self.if_dropout else downsampled
        features = self.ops(features)
        # Residual sum with the downsampled activation.
        return self.relu2(paddle.add(features, downsampled))
class UpTransition(nn.Layer):
    """Decoder stage of the VNet.

    1. Optionally apply dropout to the input and to the skip input
       (generalization).
    2. Upsample with Conv3DTranspose.
    3. Concatenate the upsampled tensor with the skip tensor
       (multi-level feature fusion).
    4. Refine with ``nConvs`` LUConv units and residually add the fused
       tensor (residual + non-linearity).
    """

    def __init__(
            self,
            inChans,
            outChans,
            nConvs,
            elu,
            dropout=False,
            dropout2=False,
            upsample_stride_size=(2, 2, 2),
            kernel=(2, 2, 2), ):
        super(UpTransition, self).__init__()
        self.up_conv = nn.Conv3DTranspose(
            inChans,
            outChans // 2,
            kernel_size=kernel,
            stride=upsample_stride_size)
        self.bn1 = nn.BatchNorm3D(outChans // 2)
        self.relu1 = nn.ELU() if elu else nn.PReLU(outChans // 2)
        self.relu2 = nn.ELU() if elu else nn.PReLU(outChans)
        self.if_dropout = dropout
        self.if_dropout2 = dropout2
        self.dropout1 = nn.Dropout3D()
        self.dropout2 = nn.Dropout3D()
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x, skipx):
        upstream = self.dropout1(x) if self.if_dropout else x
        lateral = self.dropout2(skipx) if self.if_dropout2 else skipx
        upsampled = self.relu1(self.bn1(self.up_conv(upstream)))
        fused = paddle.concat((upsampled, lateral), 1)
        refined = self.ops(fused)
        return self.relu2(paddle.add(refined, fused))
class OutputTransition(nn.Layer):
    """Map the decoder features down to ``num_classes`` output channels."""

    def __init__(self, in_channels, num_classes, elu):
        super(OutputTransition, self).__init__()
        self.conv1 = nn.Conv3D(
            in_channels, num_classes, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm3D(num_classes)
        self.relu1 = nn.ELU() if elu else nn.PReLU(num_classes)
        self.conv2 = nn.Conv3D(num_classes, num_classes, kernel_size=1)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.relu1(self.bn1(hidden))
        return self.conv2(hidden)
class VNet(nn.Layer):
    """
    Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797

    Encoder/decoder 3D segmentation network: four DownTransition stages
    double the channels while downsampling, four UpTransition stages fuse
    skip connections on the way back up, and OutputTransition maps to
    ``num_classes`` output channels.
    """
    def __init__(
            self,
            elu=False,
            in_channels=1,
            num_classes=2,
            pretrained=None,
            kernel_size=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)),
            stride_size=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)), ):
        """Build the VNet.

        Args:
            elu: use ELU activations instead of PReLU.
            in_channels: number of input channels.
            num_classes: number of output channels.
            pretrained: optional path to a pretrained checkpoint.
            kernel_size: per-stage down/upsampling kernel sizes.
            stride_size: per-stage down/upsampling strides.
        """
        super().__init__()
        self.best_loss = 1000000
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.in_tr = InputTransition(in_channels, elu=elu)
        # Encoder: channels 16 -> 32 -> 64 -> 128 -> 256.
        self.down_tr32 = DownTransition(
            16, 1, elu, downsample_stride=stride_size[0], kernel=kernel_size[0])
        self.down_tr64 = DownTransition(
            32, 2, elu, downsample_stride=stride_size[1], kernel=kernel_size[1])
        self.down_tr128 = DownTransition(
            64,
            3,
            elu,
            dropout=True,
            downsample_stride=stride_size[2],
            kernel=kernel_size[2])
        self.down_tr256 = DownTransition(
            128,
            2,
            elu,
            dropout=True,
            downsample_stride=stride_size[3],
            kernel=kernel_size[3])
        # Decoder: mirrors the encoder and fuses skip connections.
        self.up_tr256 = UpTransition(
            256,
            256,
            2,
            elu,
            dropout=True,
            dropout2=True,
            upsample_stride_size=stride_size[3],
            kernel=kernel_size[3])
        self.up_tr128 = UpTransition(
            256,
            128,
            2,
            elu,
            dropout=True,
            dropout2=True,
            upsample_stride_size=stride_size[2],
            kernel=kernel_size[2])
        self.up_tr64 = UpTransition(
            128,
            64,
            1,
            elu,
            upsample_stride_size=stride_size[1],
            kernel=kernel_size[1])
        self.up_tr32 = UpTransition(
            64,
            32,
            1,
            elu,
            upsample_stride_size=stride_size[0],
            kernel=kernel_size[0])
        self.out_tr = OutputTransition(32, num_classes, elu)
        self.pretrained = pretrained
        self.init_weight()
    def init_weight(self):
        """Load the full pretrained checkpoint when ``pretrained`` is set."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
    def forward(self, x, additional_features):  # [4, 1, 512, 512, 12]
        """Run the network; ``additional_features`` (e.g. projected click
        maps) are added to the stem output when provided."""
        x = self.in_tr(x)  # dropout cause a lot align problem
        if additional_features is not None:  # todo check shape [B, 16, H, W, D] # [4, 16, 512, 512, 12] #
            x = x + additional_features
        out32 = self.down_tr32(x)  # [4, 32, 256, 256, 9]
        out64 = self.down_tr64(out32)  # [4, 64, 128, 128, 8]
        out128 = self.down_tr128(out64)  # [4, 128, 64, 64, 4]
        out256 = self.down_tr256(out128)  # [4, 256, 32, 32, 2]
        out = self.up_tr256(out256, out128)  # [4, 256, 64, 64, 4]
        out = self.up_tr128(out, out64)  # [4, 128, 128, 128, 8]
        out = self.up_tr64(out, out32)  # [4, 64, 256, 256, 9]
        out = self.up_tr32(out, x)  # [4, 32, 512, 512, 12]
        out = self.out_tr(out)  # [4, num_classes, 512, 512, 12]
        return out
class VNetModel(ISModel3D):
    """Interactive segmentation model that uses a VNet as its backbone."""

    def __init__(self,
                 elu=False,
                 in_channels=1,
                 num_classes=2,
                 pretrained=None,
                 kernel_size=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)),
                 stride_size=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)),
                 norm_layer=nn.BatchNorm3D,
                 **kwargs):
        super().__init__(norm_layer=norm_layer, **kwargs)
        # No backbone lr multiplier here: the VNet itself is the backbone.
        self.feature_extractor = VNet(
            elu=elu,
            in_channels=in_channels,
            num_classes=num_classes,
            pretrained=pretrained,
            kernel_size=kernel_size,
            stride_size=stride_size, )

    def backbone_forward(self, image, coord_features=None):
        # TODO(review): consider a richer fusion of the click features.
        logits = self.feature_extractor(image, coord_features)
        # Both keys expose the same [B, num_classes, H, W, D] tensor.
        return {
            "instances": logits,
            "instances_aux": logits,
        }
import paddle
import paddle.nn as nn
import numpy as np
import paddle.nn.functional as F
class BaseTransform(object):
    """Abstract pre/post-processing step applied around network inference."""

    def __init__(self):
        # Set to True by subclasses whose transform alters the image tensor.
        self.image_changed = False

    def transform(self, image_nd, clicks_lists):
        """Map the input image and clicks into the network's space."""
        raise NotImplementedError

    def inv_transform(self, prob_map):
        """Map a predicted probability map back to the original space."""
        raise NotImplementedError

    def reset(self):
        """Clear any per-image state."""
        raise NotImplementedError

    def get_state(self):
        """Return a serializable snapshot of the transform's state."""
        raise NotImplementedError

    def set_state(self, state):
        """Restore a snapshot produced by :meth:`get_state`."""
        raise NotImplementedError
class SigmoidForPred(BaseTransform):
    """Identity on the way in; applies a sigmoid on the way back out."""

    def transform(self, image_nd, clicks_lists):
        # Nothing to do before inference.
        return image_nd, clicks_lists

    def inv_transform(self, prob_map):
        # Turn raw logits into probabilities.
        return F.sigmoid(prob_map)

    def reset(self):
        pass

    def get_state(self):
        return None

    def set_state(self, state):
        pass
class BatchImageNormalize3D:
    """Per-channel ``(x - mean) / std`` normalization for NCDHW batches."""

    def __init__(self, mean, std):
        # Shape (1, C, 1, 1, 1) broadcasts over batch and spatial dims.
        bc_shape = (1, -1, 1, 1, 1)
        self.mean = paddle.to_tensor(
            np.asarray(mean).reshape(bc_shape)).astype("float32")
        self.std = paddle.to_tensor(
            np.asarray(std).reshape(bc_shape)).astype("float32")

    def __call__(self, tensor):
        return (tensor - self.mean) / self.std
class ScaleLayer(nn.Layer):
    """Learnable positive scalar multiplier with a learning-rate multiplier.

    The parameter stores ``init_value / lr_mult`` so that the effective scale
    seen in :meth:`forward` starts at ``init_value``.
    """

    def __init__(self, init_value=1.0, lr_mult=1):
        super().__init__()
        self.lr_mult = lr_mult
        self.scale = self.create_parameter(
            shape=[1],
            dtype="float32",
            default_initializer=nn.initializer.Constant(init_value / lr_mult))

    def forward(self, x):
        effective_scale = paddle.abs(self.scale * self.lr_mult)
        return x * effective_scale
class DistMaps3D(nn.Layer):
    """Turn 3D click coordinates into per-voxel coordinate feature maps.

    With ``use_disks`` each click yields a binary disk of ``norm_radius``;
    otherwise a tanh-squashed normalized distance map is produced.
    """

    def __init__(self,
                 norm_radius,
                 spatial_scale=1.0,
                 cpu_mode=False,
                 use_disks=False):  # (1, 1.0, False, True)
        super(DistMaps3D, self).__init__()
        self.spatial_scale = spatial_scale
        self.norm_radius = norm_radius
        self.cpu_mode = cpu_mode
        self.use_disks = use_disks
        if self.cpu_mode:
            # Optional Cython implementation computed on the CPU.
            from util.cython import get_dist_maps
            self._get_dist_maps = get_dist_maps

    def get_coord_features(self, points, batchsize, rows, cols,
                           layers):  # [B, num_points*2, 4]
        """Return [B, 2, rows, cols, layers] maps for positive/negative clicks."""
        if self.cpu_mode:
            coords = []
            for i in range(batchsize):
                norm_delimeter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius
                coords.append(
                    self._get_dist_maps(points[i].numpy().astype("float32"),
                                        rows, cols, norm_delimeter))
            coords = paddle.to_tensor(np.stack(
                coords, axis=0)).astype("float32")
        else:
            num_points = points.shape[1] // 2  # [1, 2, 4]
            points = points.reshape(
                [-1, paddle.shape(points)[2]])  # [B*num_points*2, 4]
            points, points_order = paddle.split(
                points, [3, 1], axis=1)  # [2, 3]
            # [B*num_points*2, 3], [B*num_points*2, 1]
            # NOTE(review): invalid_points is computed but never used below.
            invalid_points = paddle.max(points, axis=1, keepdim=False) < 0
            row_array = paddle.arange(
                start=0, end=rows, step=1, dtype="float32")
            col_array = paddle.arange(
                start=0, end=cols, step=1, dtype="float32")
            layer_array = paddle.arange(
                start=0, end=layers, step=1, dtype="float32")
            coord_rows, coord_cols, coor_layers = paddle.meshgrid(
                row_array,
                col_array,
                layer_array  # [512, 512, 12]
            )  # len is 3 [rows, cols, layers]
            coords = paddle.unsqueeze(
                paddle.stack(
                    [coord_rows, coord_cols, coor_layers], axis=0),
                axis=0).tile(  # [B*num_points*2, 3, rows, cols, layers]
                    [paddle.shape(points)[0], 1, 1, 1,
                     1])  # [B*num_points*2 | 768, 3, 512, 512, 12] # repeat
            add_xy = (points * self.spatial_scale).reshape(
                [points.shape[0], points.shape[1], 1, 1, 1])
            # [B*num_points*2, 3, 1, 1, 1]
            # Subtract the click position from every voxel coordinate: the
            # result is 0 at the click and grows with distance.
            coords = coords - add_xy  # [B*num_points*2, 3, 512, 512, 12]
            if not self.use_disks:
                coords = coords / (self.norm_radius * self.spatial_scale)
            coords = coords * coords  # squared per-axis offsets
            coords[:, 0] += coords[:, 1] + coords[:, 2]
            coords = coords[:, :1]  # [B*num_points*2, 1, rows, cols, layers]
            # [B*2, num_points, 1, rows, cols, layers]
            coords = coords.reshape([-1, num_points, 1, rows, cols, layers])
            # Keep, per voxel, the smallest squared distance over all points.
            coords = paddle.min(coords, axis=1)
            coords = coords.reshape([-1, 2, rows, cols, layers])
            # [B, 2, rows, cols, layers] [B, 2, 512, 512, 12]
        if self.use_disks:
            coords = (coords <=
                      (self.norm_radius * self.spatial_scale)**2).astype(
                          "float32")  # binary disk around each click
        else:
            coords = paddle.tanh(paddle.sqrt(coords) * 2)
        return coords

    def forward(self, x, coords):  # [16, 1, 512, 512, 12], [16, 48, 4]
        """``x`` supplies only the spatial shape; features come from ``coords``."""
        batchsize = paddle.shape(x)[0]
        rows, cols, layers = paddle.shape(x)[2:5]
        return self.get_coord_features(coords, batchsize, rows, cols, layers)
import os
import sys
import logging
logging.getLogger().setLevel(logging.ERROR)
import numpy as np
from copy import deepcopy
sys.path.append(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "./.."))
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.inference import create_predictor, Config
from inference.ops import DistMaps3D, ScaleLayer, BatchImageNormalize3D, SigmoidForPred
class Click:
    """A single user click: 3D coordinates plus positive/negative polarity.

    Args:
        is_positive: True for a foreground click, False for background.
        coords: click coordinates (a 3-tuple).
        indx: optional click index.
    """

    def __init__(self, is_positive, coords, indx=None):
        if coords is None or is_positive is None:
            # BUGFIX: the message previously contained literal "{}"
            # placeholders because .format() was never called on it.
            raise ValueError(
                "The coord is {}, is_positive is {} and one of them is None, "
                "but none of them should be.".format(coords, is_positive))
        self.coords = coords
        self.is_positive = is_positive
        # BUGFIX: the ``indx`` argument was accepted but silently discarded
        # (the attribute was always set to None).
        self.index = indx

    @property
    def coords_and_indx(self):
        # Despite the name, only the coordinates are exposed here.
        return (*self.coords, )

    def copy(self, **kwargs):
        """Deep-copy this click, overriding attributes from ``kwargs``."""
        self_copy = deepcopy(self)
        for k, v in kwargs.items():
            setattr(self_copy, k, v)
        return self_copy
class BasePredictor(object):
    """Interactive 3D segmentation predictor built on paddle-inference.

    Holds the current volume and the previous prediction; each call converts
    the accumulated clicks into coordinate feature maps and runs the exported
    static graph on image + features.
    """

    def __init__(self,
                 model_path,
                 param_path,
                 net_clicks_limit=None,
                 with_mask=True,
                 norm_radius=2,
                 spatial_scale=1.0,
                 device="gpu",
                 enable_mkldnn=False,
                 **kwargs):
        """Create the inference config and predictor.

        Args:
            model_path: path to the exported ``*.pdmodel`` file.
            param_path: path to the matching ``*.pdiparams`` file.
            net_clicks_limit: max clicks fed to the network (None = all).
            with_mask: feed the previous prediction as an extra channel.
            norm_radius: click disk radius for DistMaps3D.
            spatial_scale: spatial scale for DistMaps3D.
            device: "gpu" or "cpu".
            enable_mkldnn: enable MKL-DNN acceleration on CPU.
        """
        self.net_clicks_limit = net_clicks_limit
        self.original_image = None
        self.prev_prediction = None
        self.model_indx = 0
        self.click_models = None
        self.net_state_dict = None
        self.with_prev_mask = with_mask
        self.device = device
        self.enable_mkldnn = enable_mkldnn
        if not paddle.in_dynamic_mode():
            paddle.disable_static()
        # Normalization statistics must match those used at training time.
        self.normalization = BatchImageNormalize3D(
            [0.00040428873, ],
            [0.00059983705, ], )
        self.transforms = [SigmoidForPred()]  # apply sigmoid after pred
        # TODO: expose radius / spatial_scale tuning here.
        self.dist_maps = DistMaps3D(
            norm_radius=norm_radius,
            spatial_scale=spatial_scale,
            cpu_mode=False,
            use_disks=True)
        # Build the paddle-inference predictor configuration.
        self.pred_cfg = Config(model_path, param_path)
        self.pred_cfg.disable_glog_info()
        self.pred_cfg.enable_memory_optim()
        self.pred_cfg.switch_ir_optim(True)
        if self.device == "cpu":
            self._init_cpu_config()
        else:
            self._init_gpu_config()
        self.predictor = create_predictor(self.pred_cfg)

    def _init_gpu_config(self):
        """Run inference on GPU 0 with a 100 MB initial workspace."""
        logging.info("Use NVIDIA GPU")
        self.pred_cfg.enable_use_gpu(100, 0)

    def _init_cpu_config(self):
        """Run inference on x86 CPU, optionally through MKL-DNN."""
        logging.info("Use x86 CPU")
        self.pred_cfg.disable_gpu()
        if self.enable_mkldnn:
            logging.info("Use MKLDNN")
            # cache 10 different shapes for mkldnn
            # self.pred_cfg.set_mkldnn_cache_capacity(10)  # cannot use on MAC
            self.pred_cfg.enable_mkldnn()
        self.pred_cfg.set_cpu_math_library_num_threads(10)

    def set_input_image(self, image):
        """Load a new volume and reset transforms and the previous prediction.

        ``image`` may be anything ``paddle.to_tensor`` accepts (ndarray,
        list, tuple, scalar or paddle.Tensor).
        """
        self.original_image = paddle.to_tensor(image).astype(
            "float32") / 255  # scale intensities into [0, 1]
        for transform in self.transforms:
            transform.reset()
        if len(self.original_image.shape) == 4:
            # Add the batch dimension: (C, H, W, D) -> (1, C, H, W, D).
            self.original_image = self.original_image.unsqueeze(0)
        # Start from an all-zero previous mask.
        self.prev_prediction = paddle.zeros_like(
            self.original_image[:, :1, :, :, :])
        if not self.with_prev_mask:
            self.prev_edge = paddle.zeros_like(self.original_image[:, :
                                                                   1, :, :, :])

    def get_prediction_noclicker(self, clicker, prev_mask=None):
        """Predict from the accumulated clicks without resizing the logits.

        Returns the probability map as a numpy array of shape [H, W, D].
        """
        clicks_list = clicker.get_clicks()  # TODO: accumulate several points
        input_image = self.original_image  # [1, 1, H, W, D]
        if prev_mask is None:
            if not self.with_prev_mask:
                prev_mask = self.prev_edge
            else:
                prev_mask = self.prev_prediction
        input_image = paddle.concat(
            [input_image, prev_mask], axis=1)  # [1, 2, H, W, D]
        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, [clicks_list])
        pred_logits = self._get_prediction(image_nd, clicks_lists,
                                           is_image_changed)
        pred_logits = paddle.to_tensor(pred_logits)  # [1, 1, H, W, D]
        for t in reversed(self.transforms):  # invert transforms on the output
            pred_logits = t.inv_transform(pred_logits)
        self.prev_prediction = pred_logits
        return pred_logits.numpy()[0, 0]

    def get_prediction(self, clicker, prev_mask=None):
        """Predict from the accumulated clicks; resizes logits to the input.

        Returns the probability map as a numpy array of shape [H, W, D].
        """
        clicks_list = clicker.get_clicks()
        input_image = self.original_image
        if prev_mask is None:
            if not self.with_prev_mask:
                prev_mask = self.prev_edge
            else:
                prev_mask = self.prev_prediction
        input_image = paddle.concat([input_image, prev_mask], axis=1)
        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, [clicks_list])
        pred_logits = self._get_prediction(image_nd, clicks_lists,
                                           is_image_changed)
        pred_logits = paddle.to_tensor(pred_logits)
        prediction = F.interpolate(
            pred_logits,
            mode="trilinear",
            align_corners=True,
            size=image_nd.shape[2:],
            data_format="NCDHW")
        # BUGFIX: `pred_edges` / `edge_prediction` were referenced below
        # without ever being assigned, so this method always raised
        # NameError. This predictor produces no edge output, so the edge
        # branch is kept as an explicit no-op.
        pred_edges = None
        edge_prediction = None
        for t in reversed(self.transforms):
            if pred_edges is not None:
                edge_prediction = t.inv_transform(edge_prediction)
                self.prev_edge = edge_prediction
            prediction = t.inv_transform(prediction)
        self.prev_prediction = prediction
        return prediction.numpy()[0, 0]

    def prepare_input(self, image):
        """Split the stacked input into a normalized image and the prev mask."""
        prev_mask = image[:, 1:, :, :, :]
        image = image[:, :1, :, :, :]
        image = self.normalization(image)
        return image, prev_mask

    def get_coord_features(self, image, prev_mask, points):
        """Click distance maps, optionally prefixed by the previous mask."""
        coord_features = self.dist_maps(image, points)  # [1, 2, 512, 512, 12]
        if prev_mask is not None:
            coord_features = paddle.concat(
                (prev_mask, coord_features), axis=1)  # [1, 3, 512, 512, 12]
        return coord_features

    def _get_prediction(
            self, image_nd, clicks_lists,
            is_image_changed):
        """Run the static graph on the prepared image and click features."""
        input_names = self.predictor.get_input_names()
        self.input_handle_1 = self.predictor.get_input_handle(input_names[0])
        self.input_handle_2 = self.predictor.get_input_handle(input_names[1])
        # Packed positive and negative click coordinates.
        points_nd = self.get_points_nd(clicks_lists)
        image, prev_mask = self.prepare_input(image_nd)
        coord_features = self.get_coord_features(image, prev_mask, points_nd)
        image = image.numpy().astype("float32")
        coord_features = coord_features.numpy().astype("float32")
        self.input_handle_1.copy_from_cpu(image)
        self.input_handle_2.copy_from_cpu(coord_features)
        self.predictor.run()
        output_names = self.predictor.get_output_names()
        output_handle = self.predictor.get_output_handle(output_names[0])
        output_data = output_handle.copy_to_cpu()
        return output_data

    def _get_transform_states(self):
        """Snapshot the state of every transform."""
        return [x.get_state() for x in self.transforms]

    def _set_transform_states(self, states):
        """Restore transform states captured by :meth:`_get_transform_states`."""
        assert len(states) == len(self.transforms)
        for state, transform in zip(states, self.transforms):
            transform.set_state(state)

    def apply_transforms(self, image_nd, clicks_lists):
        """Run every transform in order; report whether the image changed."""
        is_image_changed = False
        for t in self.transforms:
            image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
            is_image_changed |= t.image_changed
        return image_nd, clicks_lists, is_image_changed

    def get_points_nd(self, clicks_lists):
        """Pack clicks into a [B, 2*num_max_points, 4] padded tensor.

        The first half along axis 1 holds positive clicks, the second half
        negative clicks; missing slots are padded with (-1, -1, -1, -1).
        NOTE(review): padding tuples have 4 entries — confirm click coords
        are 4-tuples too, otherwise paddle.to_tensor will fail on ragged rows.
        """
        total_clicks = []
        # BUGFIX: the original call passed clicks_lists as an extra positional
        # argument with no %s placeholder, which triggers a formatting error
        # inside the logging machinery.
        logging.info("check_list %s", clicks_lists)
        num_pos_clicks = [
            sum(x.is_positive for x in clicks_list)
            for clicks_list in clicks_lists
        ]
        num_neg_clicks = [
            len(clicks_list) - num_pos
            for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
        ]
        num_max_points = max(num_pos_clicks + num_neg_clicks)
        if self.net_clicks_limit is not None:
            num_max_points = min(self.net_clicks_limit, num_max_points)
        num_max_points = max(1, num_max_points)
        for clicks_list in clicks_lists:
            clicks_list = clicks_list[:self.net_clicks_limit]
            pos_clicks = [
                click.coords_and_indx for click in clicks_list
                if click.is_positive
            ]
            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)
                                       ) * [(-1, -1, -1, -1)]
            neg_clicks = [
                click.coords_and_indx for click in clicks_list
                if not click.is_positive
            ]
            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)
                                       ) * [(-1, -1, -1, -1)]
            total_clicks.append(pos_clicks + neg_clicks)
        return paddle.to_tensor(total_clicks)

    def get_states(self):
        """Return transform states plus the previous prediction."""
        return {
            "transform_states": self._get_transform_states(),
            "prev_prediction": self.prev_prediction,
        }

    def set_states(self, states):
        """Restore a snapshot produced by :meth:`get_states`."""
        self._set_transform_states(states["transform_states"])
        self.prev_prediction = states["prev_prediction"]
import numpy as np
import SimpleITK as sitk
def resampleImage(refer_image,
                  out_size,
                  out_spacing=None,
                  interpolator=sitk.sitkLinear):
    """Resample a SimpleITK image to ``out_size``, adjusting the spacing.

    When ``out_spacing`` is omitted it is derived so that the physical extent
    of the image is preserved.

    Returns:
        tuple (resampled image, out_spacing used).
    """
    if out_spacing is None:
        # Keep the physical extent: scale spacing by the size ratio.
        size_ratio = np.array(refer_image.GetSize()) / np.array(out_size)
        out_spacing = tuple(size_ratio * refer_image.GetSpacing())
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(refer_image)
    resampler.SetSize(out_size)
    resampler.SetOutputSpacing(out_spacing)
    resampler.SetTransform(sitk.Transform(3, sitk.sitkIdentity))
    resampler.SetInterpolator(interpolator)
    return resampler.Execute(refer_image), out_spacing
def crop_wwwc(sitkimg, max_v, min_v):
    """Apply intensity windowing between ``min_v`` and ``max_v``.

    Must match the window width/level preprocessing used at training time.
    """
    windowing = sitk.IntensityWindowingImageFilter()
    windowing.SetWindowMaximum(max_v)
    windowing.SetWindowMinimum(min_v)
    return windowing.Execute(sitkimg)
def GetLargestConnectedCompont(binarysitk_image):
    """Keep only the largest connected component of a binary mask.

    Returns a numpy array where the largest component is 255 and everything
    else is 0; when no labels exist the raw component labels are returned
    unchanged (matching the original behavior).
    """
    cc = sitk.ConnectedComponent(binarysitk_image)
    stats = sitk.LabelIntensityStatisticsImageFilter()
    stats.SetGlobalDefaultNumberOfThreads(8)
    stats.Execute(cc, binarysitk_image)  # statistics per label of the mask
    labels = stats.GetLabels()
    labelmaskimage = sitk.GetArrayFromImage(cc)
    outmask = labelmaskimage.copy()
    if labels:
        # max() returns the first label with maximal physical size, which
        # matches the original strict-greater scan's tie-breaking.
        biggest = max(labels, key=stats.GetPhysicalSize)
        outmask[labelmaskimage == biggest] = 255
        outmask[labelmaskimage != biggest] = 0
    return outmask
import os
import sys
sys.path.append(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "./.."))
import yaml
import paddle
import paddle.nn as nn
from paddleseg.utils import logger
from inference.models import VNetModel
def main():
    """Export the interactive VNet to a paddle static graph and write a
    deploy YAML referencing the exported model/param files."""
    model = VNetModel(
        elu=False,
        in_channels=1,
        num_classes=1,
        pretrained="/ssd2/tangshiyu/Code/EISeg-3D/experiments/3D_interseg/mrispineseg_vnet/174/checkpoints/049.pdparams",  # "pretrained_models/vnet_model.pdparams",
        kernel_size=[[2, 2, 4], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
        stride_size=[[2, 2, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2]],
        with_aux_output=False,
        use_leaky_relu=True,
        use_rgb_conv=False,
        use_disks=True,
        norm_radius=2,
        with_prev_mask=True, )
    model.set_dict(paddle.load("model_checkpoints/039.pdparams"))
    # model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # on or off did not change
    model.eval()
    print("Loaded trained params of model successfully")
    # Two inputs: the image volume and the projected click feature maps.
    new_net = paddle.jit.to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None, 1, None, None, None], dtype="float32"),
            paddle.static.InputSpec(
                shape=[None, 3, None, None, None],
                dtype="float32"),
        ], )
    save_dir = "output_cpu"
    os.makedirs(save_dir, exist_ok=True)
    paddle.jit.save(new_net, os.path.join(save_dir, "static_Vnet_model"))
    # BUGFIX: the YAML was previously written under "static_VNet_model"
    # (different capitalization and a directory that is never created).
    # Write it next to the exported files so its relative paths resolve.
    yml_file = os.path.join(save_dir, "vnet_deploy.yaml")
    with open(yml_file, "w") as file:
        data = {
            "Deploy": {
                "model": "static_Vnet_model.pdmodel",
                "params": "static_Vnet_model.pdiparams"
            }
        }
        yaml.dump(data, file)
    logger.info("Model is saved in {}".format(save_dir))
if __name__ == "__main__":
main()
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
简体中文 | [English](README_en.md)
<div align="center">
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/188449455-cd4e4099-6e70-44ca-b8de-57bab04c187c.png" align="middle" width="500" />
</p>
**专注用户友好、高效、智能的3D医疗图像标注平台** <img src="https://user-images.githubusercontent.com/34859558/188409382-467c4c45-df5f-4390-ac40-fa24149d4e16.png" width="30"/>
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
![python version](https://img.shields.io/badge/python-3.6+-orange.svg)
![support os](https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-yellow.svg)
</div>
## <img src="https://user-images.githubusercontent.com/34859558/188422593-4bc47c72-866a-4374-b9ed-1308c3453165.png" width="30"/> 简介
3D 医疗数据标注是训练 3D 图像分割模型的重要一环,但 3D 医疗数据标注依赖专业人士进行费时费力的手工标注。 标注效率的低下导致了大规模标注数据的缺乏,从而严重阻碍了医疗AI的发展。为了解决这个问题,我们推出了基于交互式分割的3D医疗图像智能标注平台 EISeg-Med3D。
EISeg-Med3D 是一个用于智能医学图像分割的 3D Slicer 插件,通过使用训练的交互式分割 AI 模型来进行交互式医学图像标注。它安装简单、使用方便,结合高精度的预测模型,可以获取比手工标注**数十倍**的效率提升。目前我们的医疗标注提供了在指定的 [MRI 椎骨数据](https://aistudio.baidu.com/aistudio/datasetdetail/81211)上的使用体验,如果有其他数据上的3D标注需求,可以[联系我们](https://github.com/PaddlePaddle/PaddleSeg/issues/new/choose)
<div align="center">
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/188415269-10526530-0415-4632-8223-0e5d755db29c.gif" align="middle" width="900"/>
</p>
</div>
## <img src="https://user-images.githubusercontent.com/34859558/188419267-bd117697-7456-4c72-8cbe-1272264d4fe4.png" width="30"/> 特性
* **高效**:每个类别只需**数次点击**直接生成3d分割结果,从此告别费时费力的手工标注。
* **准确**:点击 3 点单物体 mIOU 即可达到 **0.85**,配合搭载机器学习算法和手动标注的标注编辑器,精度 100% 不是梦。
* **便捷**:三步轻松安装;标注结果、进度自动保存;标注结果透明度调整提升标注准确度;用户友好的界面交互,让你标注省心不麻烦。
*************
## 目录结构
0. [最新消息](#最新消息)
1. [EISeg-Med3D 模型介绍](#EISeg-Med3D模型介绍)
2. [使用指南](#使用指南)
3. [TODO](#TODO)
4. [License](#License)
5. [致谢](#致谢)
## <img src="https://user-images.githubusercontent.com/34859558/190043516-eed25535-10e8-4853-8601-6bcf7ff58197.png" width="30"/> 最新消息
- [2022-09] EISeg-Med3D 正式发布,包含在指定椎骨数据上的高精度模型的**用户友好、高效、智能的3D医疗图像标注平台**
## <img src="https://user-images.githubusercontent.com/34859558/190049708-7a1cee3c-322b-4263-9ed0-23051825b1a6.png" width="30"/> EISeg-Med3D 模型介绍
EISeg-Med3D模型结构如下图所示,我们创新性地将3D模型引入医疗交互式分割中,并修改 RITM 的采点模块和生成点击特征模块和3D数据兼容,从而直接进行3D医学图像的标注,从模型层面在2D标注的基础上实现标注的**更精准,更高效**
整体模型包含点击生成模块、点击特征生成模块、点击特征和输入图像融合以及分割模型四个部分:
* 点击生成模块 3D click sampler:直接基于 3D 标签数据进行正负点采样,其中正点为3D目标体中心位置随机点,负点为3D目标体边缘随机采点。
* 点击特征生成模块 3D feature extractor:在生成点击之后,为了扩大点击的影响范围,通过disk的形式在原有点击的基础上生成半径R的圆球,扩大特征覆盖范围。
* 点击特征和输入图像融合:将输入图像和生产的点击特征经过卷积块进行重新映射和进行相加融合,从而网络同时获取图像和点击信息,并对特定区域进行3D分割获得标注结果。
* 分割模型:分割模型沿用 3D 分割模型Vnet,在图像和点击信息综合下生成如图所示的预测结果,并在Dice损失和CE损失的约束下逼近真实结果。从而在预测阶段,输入图像和指定点击后基于点击目标生成期望的标注结果。
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/190861789-793bd9f3-17a8-49d6-a2a7-bce82696d28e.png" width="80.6%" height="20%">
<p align="center">
EISeg-Med3D模型
</p>
</p>
## <img src="https://user-images.githubusercontent.com/34859558/188439970-18e51958-61bf-4b43-a73c-a9de3eb4fc79.png" width="30"/> 使用指南
EISeg-Med3D 的使用整体流程如下图所示,我们将按照环境安装、模型下载和使用步骤三部分说明,其中使用步骤也可以参见简介中的视频。
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187884472-32e6dd36-be7b-4b32-b5d1-c0ccd743e1ca.png" width="60.6%" height="20%">
<p align="center">
整体使用流程
</p>
</p>
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187884776-470195d6-46d1-4e2b-8403-0cb320f185ec.png" width="80.6%" height="60%">
<p align="center">
智能标注模块流程
</p>
</p>
### 1. 环境安装
1. 下载并安装3D slicer软件:[3D slicer 官网](https://www.slicer.org/)
2. 下载 EISeg-Med3D 代码:
```bash
git clone https://github.com/PaddlePaddle/PaddleSeg.git
```
3. 安装Paddle,先在slicer的python interpreter中找到解释器名称,随后在cmd中参考[快速安装文档](https://www.paddlepaddle.org.cn/install/quick)安装PaddlePaddle。
```bash
import sys
import os
sys.executable # "D:/xxxx/Slicer 5.0.3/bin/PythonSlicer.exe"
```
进入Windows 下CMD,比如Windows、CUDA 11.1,安装GPU版本,执行如下命令:
```bash
"D:/xxxx/Slicer 5.0.3/bin/PythonSlicer.exe" -m pip install paddlepaddle-gpu==2.3.1.post111 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.html
```
<summary><b> 常见问题 </b></summary>
1. 安装PaddlePaddle之后出现FileNotFoundError:
<details>
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/189288387-4773c35a-ac8e-421d-bfed-2264ac57cda5.png" width="70.6%" height="20%">
</p>
解决方式:进入到报错位置所在的 subprocess.py, 修改Popen类的属性 shell=True 即可。
</details>
2. ERROR: No .egg-info directory found in xxx:
<details>
解决方式:参考 https://github.com/PaddlePaddle/PaddleSeg/issues/2718,执行以下指令能成功进行安装。
```python
"D:/xxxx/Slicer 5.0.3/bin/PythonSlicer.exe" -m pip uninstall setuptools
"D:/xxxx/Slicer 5.0.3/bin/PythonSlicer.exe" -m pip install paddleseg simpleitk
"D:/xxxx/Slicer 5.0.3/bin/PythonSlicer.exe" -m pip install setuptools
```
</details>
3. 点击确认load module后,提示 One or more requested modules and/or dependencies may not have been loaded。
<details>
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/204699311-8a12e976-904f-46e0-8bbf-9d9f0290393d.png" width="30.6%" height="20%">
</p>
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/204699334-a31e6827-b907-456a-a686-1c1f3ac6014d.png" width="50.6%" height="20%">
</p>
解决方式:有部分需要import的库没有安装,例如paddle/paddleseg/simpleitk等,使用第二步的步骤进行安装后重启slicer并重新导入。
</details>
4. Fail to open extention: xxx/PaddleSeg/EISeg/med3d/EISefMed3D
<details>
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/204699489-6b75d9f2-5cf6-42d2-9d74-17894fc3e00b.png" width="50.6%" height="20%">
</p>
解决方式:需要选择的加载路径为xxx/PaddleSeg/EISeg/med3d/ 而不是xxx/PaddleSeg/EISeg/med3d/EISefMed3D
</details>
### 2. 模型、数据下载
目前我们提供在下列模型和数据上的试用体验,可以下载表格中模型和数据到指定目录,并将模型和数据进行解压缩操作用于后续加载:
<p align="center">
| 数据 | 模型 | 下载链接 |
|:-:|:-:|:-:|
| MRI椎骨数据 | 交互式 Vnet |[模型](链接: https://pan.baidu.com/s/1vu0ZIbGumlFvRlMGbMvWAg )-pw: dt8q \| [椎骨数据](https://aistudio.baidu.com/aistudio/datasetdetail/81211)|
</p>
### 3. 使用步骤
#### 0. 双击打开 3D Slicer
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/190042214-391fbdc4-a007-42d2-9019-f1ff3d97b6eb.png" width="20.6%" height="20%">
</p>
#### 1. 加载插件
* 找到 Extension wizard 插件:
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/188458289-b59dc5e3-34eb-4d40-b18b-ce0b35c066c6.png" width="60.6%" height="20%">
</p>
* 点击 Select Extension,并选择到 PaddleSeg/EISeg/med3d 目录,并点击加载对应模块,等待 Slicer 进行加载。
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/188458463-066ff0b6-ff80-4d0d-aca0-3b3b12f710ef.png" width="60.6%" height="20%">
</p>
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/204699311-8a12e976-904f-46e0-8bbf-9d9f0290393d.png" width="60.6%" height="20%">
</p>
* 加载完后,切换到 EISegMed3D模块。
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/188458684-46465fed-fdde-43dd-a97c-5da7678f3f99.png" width="60.6%" height="20%">
</p>
#### 2. 加载模型
*```Model Settings```中加载保存在本地的模型,点击```Model Path```路径选择框后面的```...```的按钮,选择后缀名为```.pdmodel```的本地文件,点击```Param Path```路径选择框后面的```...```的按钮,选择后缀名为```.pdiparams```的本地文件。
* 点击```Load Static Model```按钮,此时会有弹窗提示```Sucessfully loaded model to gpu!```,表示模型已经加载进来。
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187881886-e4d99fb4-c697-48a5-8cd7-a5ab83c7791d.PNG" width="70.6%" height="20%">
</p>
#### 3. 加载图像
* 点击```Data Folder```后面的按钮,选择待标注的医学图像文件所在路径后,会自动把该路径下的所有图像全部加载,此时可以在```Progress```中查看加载进来的所有图像以及当前已标注状态。
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187882370-6f9a8f21-8a96-4a4c-8451-18d6e608f7e4.PNG" width="70.6%" height="20%">
</p>
#### 4. 开始标注
*```Segment Editor```中点击```Add/Remove```按钮便可自行添加标签或是删除标签,添加标签时会有默认命名,也可以双击标签自行给标签命名。
* 添加标签完毕后即可选中某个标签,点击```Positive Point```或是```Negative Point```后的按钮即可开始交互式标注。
* 点击```Finish Segment```按钮,即可结束当前所选标签下的标注,此时可点击左侧的橡皮擦等工具对标注结果进行精修。或者可重复以上步骤进行下一个对象的标注,否则可点击```Finish Scan```按钮,便会切换下一张图像。
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187882400-8ee24469-6cb7-4c6a-acf8-df0e14e3f2a7.PNG" width="70.6%" height="20%">
</p>
* 关于精细修改标注工具的使用,详细可见[Slicer Segment editor](https://slicer.readthedocs.io/en/latest/user_guide/modules/segmenteditor.html)
#### 5. 切换图像
* 点击```Prev Scan```按钮可以切换上一张图像到当前视图框内。
* 点击```Next Scan```按钮可以切换下一张图像到当前视图框内。
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187882440-e1c3cc03-b79e-4ad8-9987-20af42c9ae01.PNG" width="70.6%" height="20%">
</p>
#### 6. 查看标注进程
*```Progress```中的```Annotation Progress```后面的进度条中可以查看当前加载进来的图像标注进程。
* 双击```Annotation Progress```下方表格中某一张图像文件名,便可以自动跳转到所选图像。
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187882460-0eb0fc86-d9d7-4733-b812-85c62b1b9281.PNG" width="70.6%" height="20%">
</p>
<!-- </details> -->
## <img src="https://user-images.githubusercontent.com/34859558/190046674-53e22678-7345-4bf1-ac0c-0cc99718b3dd.png" width="30"/> TODO
未来,我们想在这几个方面来继续发展EISeg-Med3D,欢迎加入我们的开发者小组。
- [ ] 在更大的椎骨数据集上进行训练,获取泛化性能更好的标注模型。
- [ ] 开发在多个器官上训练的模型,从而获取能泛化到多器官的标注模型。
## <img src="https://user-images.githubusercontent.com/34859558/188446853-6e32659e-8939-4e65-9282-68909a38edd7.png" width="30"/> License
EISeg-Med3D 的 License 为 [Apache 2.0 license](LICENSE).
## <img src="https://user-images.githubusercontent.com/34859558/188446803-06c54d50-f2aa-4a53-8e08-db2253df52fd.png" width="30"/> 致谢
感谢 <a href="https://www.flaticon.com/free-icons/idea" title="idea icons"> Idea icons created by Vectors Market - Flaticon</a> 给我们提供了好看的图标
English | [简体中文](README.md)
<div align="center">
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/188449455-cd4e4099-6e70-44ca-b8de-57bab04c187c.png" align="middle" width="500" />
</p>
**An easy-to-use, efficient, smart 3D medical image annotation platform** <img src="https://user-images.githubusercontent.com/34859558/188409382-467c4c45-df5f-4390-ac40-fa24149d4e16.png" width="30"/>
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
![python version](https://img.shields.io/badge/python-3.6+-orange.svg)
![support os](https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-yellow.svg)
</div>
## <img src="https://user-images.githubusercontent.com/34859558/188422593-4bc47c72-866a-4374-b9ed-1308c3453165.png" width="30"/> Brief Introduction
3D medical data annotation is an important part of training 3D image segmentation models and promotes disease diagnosis and treatment prediction, but 3D medical data annotation relies on time-consuming and laborious manual annotation by professionals. The low labeling efficiency leads to the lack of large-scale labeling data, which seriously hinders the development of medical AI. To solve this problem, we launched EISeg-Med3D, an intelligent annotation platform for 3D medical images based on interactive segmentation.
EISeg-Med3D is a 3D slicer extension for performing **E**fficient **I**nteractive **Seg**mentation on **Med**ical images in **3D**. Users will guide a deep learning model to perform segmentation by providing positive and negative points. It is simple to install, easy to use and accurate, and can achieve a tenfold efficiency gain compared to manual labelling. At present, our medical annotation provides the try-on experience on the specified [MRI vertebral data](https://aistudio.baidu.com/aistudio/datasetdetail/81211). If there is a need for 3D annotation on other data, you can make a [contact](https://github.com/PaddlePaddle/PaddleSeg/issues/new/choose).
<div align="center">
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/188415269-10526530-0415-4632-8223-0e5d755db29c.gif" align="middle" width="900"/>
</p>
</div>
## <img src="https://user-images.githubusercontent.com/34859558/188419267-bd117697-7456-4c72-8cbe-1272264d4fe4.png" width="30"/> Feature
* **Efficient**:Each category only needs a few clicks to generate 3D segmentation results, ten times more efficient compared to time-consuming and laborious manual annotation.
* **Accurate**:The mIOU can reach 0.85 with only 3 clicks. with the segmentation editor equipped with machine learning algorithm and manual annotation, 100% accuracy is right on your hand.
* **Convenient**:Install our plugin within three steps; labeling results and progress are automatically saved; the transparency of labeling results can be adjusted to improve labeling accuracy; user-friendly interface interaction makes labeling worry-free and hassle-free.
*************
## Contents
0. [News](#news)
1. [EISeg-Med3D Model Introduction](#eiseg-med3d-model)
2. [User Guide](#user-guide)
3. [TODO](#todo)
4. [License](#license)
5. [Thanks](#attribution)
## <img src="https://user-images.githubusercontent.com/34859558/190043516-eed25535-10e8-4853-8601-6bcf7ff58197.png" width="30"/> News
- [2022-09] EISeg-Med3D is officially released, **a user-friendly, efficient and intelligent 3D medical image annotation platform** including high-precision models on specified vertebral data.
## <img src="https://user-images.githubusercontent.com/34859558/190049708-7a1cee3c-322b-4263-9ed0-23051825b1a6.png" width="30"/> EISeg-Med3D Model
The EISeg-Med3D model structure is shown in the figure below. We innovatively introduce the 3D model into medical interactive segmentation, and modify the point sampler module and the click feature extractor of RITM to be compatible with 3D data, so as to directly label 3D medical images. Compared with 2D interactive annotation on 3D images, our method is more accurate and more efficient.
The overall model includes four parts: the click generation module, the click feature generation module, the fusion of click features and the input image, and the segmentation model:
* Click generation module 3D click sampler: generate the positive and negative click through sampling on the 3D labelled data directly, where the positive point is a random point at the center of the 3D target segment, and the negative point is a random point at the edge of the 3D target segment.
* Click feature generation module 3D feature extractor: After the click is generated, a sphere with radius R is generated on the basis of the original click in the form of disk to expand the feature coverage.
* Fusion of click feature and input image: The input image and the generated click feature are remapped and fused through the convolution block, so that the network obtains information from both the image and the clicks, and performs 3D segmentation on the assigned area to obtain the annotation result.
* Segmentation model: The segmentation model is the 3D segmentation model Vnet, generates the prediction results shown in the figure, and approximates the real results under the constraints of Dice loss and CE loss. Thus, in the prediction stage, input images and specified clicks generate the desired annotation results.
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/190861789-793bd9f3-17a8-49d6-a2a7-bce82696d28e.png" width="80.6%" height="20%">
<p align="center">
EISeg-Med3D Model
</p>
</p>
## <img src="https://user-images.githubusercontent.com/34859558/188439970-18e51958-61bf-4b43-a73c-a9de3eb4fc79.png" width="30"/> User Guide
The overall process of using EISeg-Med3D is shown in the figure below. We will introduce in the following three steps including environment installation, model and data download and user guide. The steps to use our platform can also be found in the video in the introduction.
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187884472-32e6dd36-be7b-4b32-b5d1-c0ccd743e1ca.png" width="60.6%" height="20%">
<p align="center">
The overall process
</p>
</p>
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187884776-470195d6-46d1-4e2b-8403-0cb320f185ec.png" width="80.6%" height="60%">
<p align="center">
The process of AI-assisted labelling
</p>
</p>
### 1. Environment Installation
1. Download and install 3D slicer:[Slicer website](https://www.slicer.org/)
2. Download code of EISeg-Med3D:
```bash
git clone https://github.com/PaddlePaddle/PaddleSeg.git
```
3. Find the slicer python executor in python interpreter of slicer. Then install PaddlePaddle refer to [install doc](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html) in windows cmd.
```bash
import sys
import os
sys.executable # "D:/xxxx/Slicer 5.0.3/bin/PythonSlicer.exe"
```
In CMD, if you install on Windows with CUDA 11.1 GPU, run the following command:
```bash
"D:/xxxx/Slicer 5.0.3/bin/PythonSlicer.exe" -m pip install paddlepaddle-gpu==2.3.1.post111 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.html
```
<summary><b> Common FAQ </b></summary>
1. FileNotFoundError when installing PaddlePaddle:
<details>
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/189288387-4773c35a-ac8e-421d-bfed-2264ac57cda5.png" width="70.6%" height="20%">
</p>
Solution:Find the subprocess.py that raises the error, and change the Popen attribute shell=True.
</details>
2. ERROR: No .egg-info directory found in xxx:
<details>
Solution:Please refer to https://github.com/PaddlePaddle/PaddleSeg/issues/2718. Execute the following code will solve the error.
```python
"D:/xxxx/Slicer 5.0.3/bin/PythonSlicer.exe" -m pip uninstall setuptools
"D:/xxxx/Slicer 5.0.3/bin/PythonSlicer.exe" -m pip install paddleseg simpleitk
"D:/xxxx/Slicer 5.0.3/bin/PythonSlicer.exe" -m pip install setuptools
```
</details>
3. When loading the module, the popped-up window says "One or more requested modules and/or dependencies may not have been loaded".
<details>
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/204699311-8a12e976-904f-46e0-8bbf-9d9f0290393d.png" width="30.6%" height="20%">
</p>
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/204699334-a31e6827-b907-456a-a686-1c1f3ac6014d.png" width="50.6%" height="20%">
</p>
Solution:You may forget to install some library we need to import, eg: paddle, paddleseg, simpleitk and etc. Please refer to step 1 and 2 to install and restart Slicer to import.
</details>
4. Fail to open extention: xxx/PaddleSeg/EISeg/med3d/EISefMed3D
<details>
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/204699489-6b75d9f2-5cf6-42d2-9d74-17894fc3e00b.png" width="50.6%" height="20%">
</p>
Solution:The path you need to choose is "xxx/PaddleSeg/EISeg/med3d/" but not "xxx/PaddleSeg/EISeg/med3d/EISefMed3D"
</details>
### Model and Data Downloading
Currently we provide trial experience on the following models and data:
<p align="center">
| Data | Model | Links |
|:-:|:-:|:-:|
| MRI-spine | Interactive Vnet |[pdiparams](https://pan.baidu.com/s/1Dk-PqogeJOiaEGBse3kFOA)-pw: 6ok7 \| [pdmodel](https://pan.baidu.com/s/1daFrC1C2cwCmovvLj5n3QA)-pw: sg80 \| [Spine Data](https://aistudio.baidu.com/aistudio/datasetdetail/81211)|
</p>
### User Guide
#### 1. Load the Extension
* Locate Extension wizard:
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/188458289-b59dc5e3-34eb-4d40-b18b-ce0b35c066c6.png" width="70.6%" height="20%">
</p>
* Click on "Select Extension",and choose PaddleSeg/EISeg/med3d directory, and click to load corresponding module.
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/188458463-066ff0b6-ff80-4d0d-aca0-3b3b12f710ef.png" width="70.6%" height="20%">
</p>
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/204699311-8a12e976-904f-46e0-8bbf-9d9f0290393d.png" width="70.6%" height="20%">
</p>
* After loading, switch to EISegMed3D.
<p align="center">
<img src="https://user-images.githubusercontent.com/34859558/188458684-46465fed-fdde-43dd-a97c-5da7678f3f99.png" width="70.6%" height="20%">
</p>
#### 2. Load Model
* Load the downloaded model in ```Model Settings```:
click on the ```...``` button of ```Model Path```, choose the local file with the ```.pdmodel``` suffix, and load the ```.pdiparams``` file in ```Param Path``` in the same way.
* Click on the ```Load Static Model``` button. A ```Sucessfully loaded model to gpu!``` window will pop up if the model is loaded successfully.
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187881886-e4d99fb4-c697-48a5-8cd7-a5ab83c7791d.PNG" width="70.6%" height="20%">
</p>
#### 3. Load Medical Data
* Click on the button behind ```Data Folder```, choose the folder that you saved your downloaded data. And all of the data under that folder will be loaded and you can see the labelling status of loaded data in ```Progress```.
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187882370-6f9a8f21-8a96-4a4c-8451-18d6e608f7e4.PNG" width="70.6%" height="20%">
</p>
#### 4. Switch Between Images.
* Click on the ```Prev Scan``` button to see the previous image.
* Click on the ```Next Scan``` button to see the next image.
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187882440-e1c3cc03-b79e-4ad8-9987-20af42c9ae01.PNG" width="70.6%" height="20%">
</p>
#### 5. Start to label
* Click ```Add/Remove``` in ```Segment Editor``` to add or remove the label. You can change the name of added label by double click the label item.
* Choose the label you want to annotate and click on the ```Positive Point``` or ```Negative Point``` button to enter interactive labeling mode.
* Click on the ```Finish Segment``` button to finish the annotation of the current segment; you can further edit the annotation using tools in the segment editor, or repeat the previous step to label the next category. If you have finished the annotation on this case, you can click on the ```Finish Scan``` button.
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187882400-8ee24469-6cb7-4c6a-acf8-df0e14e3f2a7.PNG" width="70.6%" height="20%">
</p>
* See [Slicer Segment editor](https://slicer.readthedocs.io/en/latest/user_guide/modules/segmenteditor.html) for using the tool in segment editor.
#### 6. Check Label Progress
* In ```Annotation Progress``` of ```Progress```, you can checkout the labelling progress of loaded images.
* Double click on one of the images in the chart below ```Annotation Progress``` to jump to the corresponding image.
<p align="center">
<img src="https://user-images.githubusercontent.com/48357642/187882460-0eb0fc86-d9d7-4733-b812-85c62b1b9281.PNG" width="70.6%" height="20%">
</p>
<!-- </details> -->
## <img src="https://user-images.githubusercontent.com/34859558/190046674-53e22678-7345-4bf1-ac0c-0cc99718b3dd.png" width="30"/> TODO
In the future, we want to continue to develop EISeg-Med3D in these aspects, welcome to join our developer team.
- [ ] Work on larger vertebrae datasets and improve generality of our spine model.
- [ ] Develop models trained on multiple organs to obtain models that generalize to multiple organs.
## <img src="https://user-images.githubusercontent.com/34859558/188446853-6e32659e-8939-4e65-9282-68909a38edd7.png" width="30"/> License
EISeg-Med3D is released under the [Apache 2.0 license](LICENSE).
## <img src="https://user-images.githubusercontent.com/34859558/188446803-06c54d50-f2aa-4a53-8e08-db2253df52fd.png" width="30"/> Attribution
Thanks to <a href="https://www.flaticon.com/free-icons/idea" title="idea icons"> Idea icons created by Vectors Market - Flaticon</a> for the fascinating icons.
[tool.black]
line-length = 120
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment