Commit 0d97cc8c authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add new model

parents
Pipeline #316 failed with stages
in 0 seconds
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import time
import os.path as osp
from functools import partial
import json
from distutils.util import strtobool
import webbrowser
from easydict import EasyDict as edict
from qtpy import QtGui, QtCore, QtWidgets
from qtpy.QtWidgets import QMainWindow, QMessageBox, QTableWidgetItem, QApplication
from qtpy.QtGui import QImage, QPixmap
from qtpy.QtCore import Qt, QByteArray, QVariant, QCoreApplication, QThread, Signal, QTimer
import cv2
import numpy as np
from PIL import Image
import paddle
import paddle.nn.functional as F
from eiseg import pjpath, __APPNAME__, logger
from widget import ShortcutWidget, PolygonAnnotation
from controller import InteractiveController
from ui import Ui_EISeg
import util
from util import COCO
from util import check_cn, normcase
import plugin.remotesensing as rs
from plugin.medical import med
from plugin.remotesensing import Raster
from plugin.n2grid import RSGrids, Grids, checkOpenGrid
from plugin.video import InferenceCore, overlay_davis
class APP_EISeg(QMainWindow, Ui_EISeg):
IDILE, ANNING, EDITING = 0, 1, 2
# IDILE:网络,权重,图像三者任一没有加载
# EDITING:多边形编辑,可以交互式,但是多边形内部不能点
# ANNING:交互式标注,只能交互式,不能编辑多边形,多边形不接hover
# 宫格标注背景颜色
GRID_COLOR = {
"idle": QtGui.QColor(255, 255, 255),
"current": QtGui.QColor(192, 220, 243),
"finised": QtGui.QColor(185, 185, 225),
"overlying": QtGui.QColor(51, 52, 227),
}
def __init__(self, parent=None):
super(APP_EISeg, self).__init__(parent)
self.settings = QtCore.QSettings(
osp.join(pjpath, "config/setting.txt"), QtCore.QSettings.IniFormat)
currentLang = self.settings.value("language")
layoutdir = Qt.RightToLeft if currentLang == "Arabic" else Qt.LeftToRight
self.setLayoutDirection(layoutdir)
# 初始化界面
self.setupUi(self)
# app变量
self._anning = False # self.status替代
self.isDirty = False # 是否需要保存
self.image = None # 可能先加载图片后加载模型,只用于暂存图片
self.predictor_params = {
"brs_mode": "NoBRS",
"with_flip": False,
"zoom_in_params": {
"skip_clicks": -1,
"target_size": (400, 400),
"expansion_ratio": 1.4,
},
"predictor_params": {
"net_clicks_limit": None,
"max_size": 800,
"with_mask": True,
},
}
self.controller = InteractiveController(
predictor_params=self.predictor_params,
prob_thresh=self.segThresh, )
self.video = InferenceCore()
self.video_images = None
self.video_masks = None
# self.controller.labelList = util.LabelList() # 标签列表
self.save_status = {
"gray_scale": True,
"pseudo_color": True,
"json": False,
"coco": True,
"cutout": True,
} # 是否保存这几个格式
self.outputDir = None # 标签保存路径
self.labelPaths = [] # 所有outputdir中的标签文件路径
self.imagePaths = [] # 文件夹下所有待标注图片路径
self.currIdx = 0 # 文件夹标注当前图片下标
self.origExt = False # 是否使用图片本身拓展名,防止重名覆盖
if self.save_status["coco"]:
self.coco = COCO()
else:
self.coco = None
self.colorMap = util.colorMap
if self.settings.value("cutout_background"):
self.cutoutBackground = [
int(c) for c in self.settings.value("cutout_background")
]
if len(self.cutoutBackground) == 3:
self.cutoutBackground += tuple([255])
else:
self.cutoutBackground = [0, 0, 128, 255]
if self.settings.value("cross_color"):
self.crossColor = [
int(c) for c in self.settings.value("cross_color")
]
else:
self.crossColor = [0, 0, 0, 127]
self.scene.setPenColor(self.crossColor)
# widget
self.dockWidgets = {
"model": [self.ModelDock],
"data": [self.DataDock],
"label": [self.LabelDock],
"seg": [self.SegSettingDock],
"rs": [self.RSDock],
"med": [self.MedDock],
"grid": [self.GridDock],
"video": [self.VideoDock],
"vseg": [self.VSTDock],
"3d": [self.TDDock],
}
# self.display_dockwidget = [True, True, True, True, False, False, False, False, False, False]
self.dockStatus = self.settings.value(
"dock_status", QVariant([]), type=list) # 所有widget是否展示
if len(self.dockStatus) != len(self.dockWidgets):
self.dockStatus = [True] * 4 + [False] * (len(self.dockWidgets) - 4)
self.settings.setValue("dock_status", self.dockStatus)
else:
self.dockStatus = [strtobool(s) for s in self.dockStatus]
self.layoutStatus = self.settings.value("layout_status",
QByteArray()) # 界面元素位置
self.recentModels = self.settings.value(
"recent_models", QVariant([]), type=list)
self.video_recentModels = self.settings.value(
"video_recent_models", QVariant([]), type=list)
self.recentFiles = self.settings.value(
"recent_files", QVariant([]), type=list)
self.config = util.parse_configs(osp.join(pjpath, "config/config.yaml"))
# 支持的图像格式
rs_ext = [".tif", ".tiff"]
img_ext = []
for fmt in QtGui.QImageReader.supportedImageFormats():
fmt = ".{}".format(fmt.data().decode())
if fmt not in rs_ext:
img_ext.append(fmt)
video_ext = [
".wmv",
".asf",
".asx",
".rm",
".rmvb",
".mp4",
".3gp",
".mov",
".m4v",
".avi",
".dat",
".mkv",
".flv",
".vob",
]
self.video_ext = video_ext
self.formats = [
img_ext, # 自然图像
[".dcm"], # 医学影像
rs_ext, # 遥感影像
video_ext, # 视频
]
# 遥感
self.raster = None
self.grid = None
self.rsRGB = [1, 1, 1] # 遥感索引
# 医疗参数
self.midx = 0 # 医疗切片索引
# 大图限制
self.thumbnail_min = 2000
# 初始化action
self.initActions()
# 更新近期记录
self.loadLayout() # 放前面
self.toggleWidget("all", warn=False)
self.updateModelMenu()
self.updateVideoModelMenu()
self.updateRecentFile()
# self.VideoDock.hide()
# 窗口
## 快捷键
self.ShortcutWidget = ShortcutWidget(self.actions, pjpath)
## 画布
self.scene.clickRequest.connect(self.canvasClick)
self.canvas.zoomRequest.connect(self.viewZoomed)
self.canvas.mousePosChanged.connect(self.scene.onMouseChanged)
self.annImage = QtWidgets.QGraphicsPixmapItem()
self.scene.addItem(self.annImage)
## 按钮点击
self.btnSave.clicked.connect(self.exportLabel) # 保存
self.listFiles.itemDoubleClicked.connect(
self.imageListClicked) # 标签列表点击
self.btnAddClass.clicked.connect(self.addLabel)
self.btnParamsSelect.clicked.connect(self.changeParam) # 模型参数选择
self.btn3DParamsSelect.clicked.connect(self.changePropgationParam)
self.cheWithMask.stateChanged.connect(self.chooseMode) # with_mask
self.btnPropagate.clicked.connect(self.on_propgation)
## 滑动
self.sldOpacity.valueChanged.connect(self.maskOpacityChanged)
self.sldClickRadius.valueChanged.connect(self.clickRadiusChanged)
self.sldThresh.valueChanged.connect(self.threshChanged)
# self.sldBrush.valueChanged.connect(self.brushChanged)
self.sldWw.valueChanged.connect(self.swwChanged)
self.sldWc.valueChanged.connect(self.swcChanged)
self.textWw.returnPressed.connect(self.twwChanged)
self.textWc.returnPressed.connect(self.twcChanged)
## 标签列表点击
self.labelListTable.cellDoubleClicked.connect(self.labelListDoubleClick)
self.labelListTable.cellClicked.connect(self.labelListClicked)
self.labelListTable.cellChanged.connect(self.labelListItemChanged)
## 功能区选择
# self.rsShow.currentIndexChanged.connect(self.rsShowModeChange) # 显示模型
for bandCombo in self.bandCombos:
bandCombo.currentIndexChanged.connect(self.rsBandSet) # 设置波段
# self.btnInitGrid.clicked.connect(self.initGrid) # 打开宫格
self.btnFinishedGrid.clicked.connect(self.saveGridLabel)
## 视频相关
self.timer.timeout.connect(self.on_time)
self.videoPlay.clicked.connect(self.on_play)
self.sldTime.valueChanged.connect(self.sframeChanged)
self.textTime.returnPressed.connect(self.tframeChanged)
self.ratio = 20
self.speedComboBox.currentIndexChanged.connect(self.on_speed)
self.preFrameButton.clicked.connect(self.turnPreFrame)
self.nextFrameButton.clicked.connect(self.turnNextFrame)
def initActions(self):
tr = partial(QtCore.QCoreApplication.translate, "APP_EISeg")
action = partial(util.newAction, self)
start = dir()
# 打开/加载/保存
open_image = action(
tr("&打开图像"),
self.openImage,
"open_image",
"OpenImage",
tr("打开一张图像进行标注"), )
open_folder = action(
tr("&打开文件夹"),
self.openFolder,
"open_folder",
"OpenFolder",
tr("打开一个文件夹下所有的图像进行标注"), )
change_output_dir = action(
tr("&改变标签保存路径"),
partial(self.changeOutputDir, None),
"change_output_dir",
"ChangeOutputDir",
tr("改变标签保存的文件夹路径"), )
load_param = action(
tr("&加载模型参数"),
self.changeParam,
"load_param",
"Model",
tr("加载一个模型参数"), )
save = action(
tr("&保存"),
self.exportLabel,
"save",
"Save",
tr("保存图像标签"), )
save_as = action(
tr("&另存为"),
partial(
self.exportLabel, saveAs=True),
"save_as",
"SaveAs",
tr("在指定位置另存为标签"), )
auto_save = action(
tr("&自动保存"),
self.toggleAutoSave,
"auto_save",
"AutoSave",
tr("翻页同时自动保存"),
checkable=True, )
# auto_save.setChecked(self.config.get("auto_save", False))
# 标注
turn_prev = action(
tr("&上一张"),
partial(self.turnImg, -1),
"turn_prev",
"Prev",
tr("翻到上一张图片"), )
turn_next = action(
tr("&下一张"),
partial(self.turnImg, 1),
"turn_next",
"Next",
tr("翻到下一张图片"), )
finish_object = action(
tr("&完成当前目标"),
self.finishObject,
"finish_object",
"Ok",
tr("完成当前目标的标注"), )
clear = action(
tr("&清除所有标注"),
self.clearAll,
"clear",
"Clear",
tr("清除所有标注信息"), )
undo = action(
tr("&撤销"),
self.undoClick,
"undo",
"Undo",
tr("撤销一次点击"), )
redo = action(
tr("&重做"),
self.redoClick,
"redo",
"Redo",
tr("重做一次点击"), )
del_active_polygon = action(
tr("&删除多边形"),
self.delActivePolygon,
"del_active_polygon",
"DeletePolygon",
tr("删除当前选中的多边形"), )
del_all_polygon = action(
tr("&删除所有多边形"),
self.delAllPolygon,
"del_all_polygon",
"DeleteAllPolygon",
tr("删除所有的多边形"), )
largest_component = action(
tr("&保留最大连通块"),
self.toggleLargestCC,
"largest_component",
"SaveLargestCC",
tr("保留最大的连通块"),
checkable=True, )
origional_extension = action(
tr("&标签和图像使用相同拓展名"),
self.toggleOrigExt,
"origional_extension",
"Same",
tr("标签和图像使用相同拓展名,用于图像中有文件名相同但拓展名不同的情况,防止标签覆盖"),
checkable=True, )
save_pseudo = action(
tr("&伪彩色保存"),
partial(self.toggleSave, "pseudo_color"),
"save_pseudo",
"SavePseudoColor",
tr("保存为伪彩色图像"),
checkable=True, )
save_pseudo.setChecked(self.save_status["pseudo_color"])
save_grayscale = action(
tr("&灰度保存"),
partial(self.toggleSave, "gray_scale"),
"save_grayscale",
"SaveGrayScale",
tr("保存为灰度图像,像素的灰度为对应类型的标签"),
checkable=True, )
save_grayscale.setChecked(self.save_status["gray_scale"])
save_json = action(
tr("&JSON保存"),
partial(self.toggleSave, "json"),
"save_json",
"SaveJson",
tr("保存为JSON格式"),
checkable=True, )
save_json.setChecked(self.save_status["json"])
save_coco = action(
tr("&COCO保存"),
partial(self.toggleSave, "coco"),
"save_coco",
"SaveCOCO",
tr("保存为COCO格式"),
checkable=True, )
save_coco.setChecked(self.save_status["coco"])
# test func
self.show_rs_poly = action(
tr("&显示遥感多边形"),
None,
"show_rs_poly",
"Show",
tr("显示遥感大图的多边形结果"),
checkable=True, )
self.show_rs_poly.setChecked(False)
self.grid_message = action(
tr("&启用宫格检测"),
None,
"grid_message",
"Show",
tr("针对每张图片启用宫格检测"),
checkable=True, )
self.grid_message.setChecked(True)
eiseg_med3D = action(
tr("&EISeg-Med3D"),
self.enterEISegMed3D,
"enterEISegMed3D",
"EISegMed3D",
tr("3D医疗交互式分割插件"), )
save_cutout = action(
tr("&抠图保存"),
partial(self.toggleSave, "cutout"),
"save_cutout",
"SaveCutout",
tr("只保留前景,背景设置为背景色"),
checkable=True, )
save_cutout.setChecked(self.save_status["cutout"])
set_cutout_background = action(
tr("&设置抠图背景色"),
self.setCutoutBackground,
"set_cutout_background",
self.cutoutBackground,
tr("抠图后背景像素的颜色"), )
close = action(
tr("&关闭"),
partial(self.saveImage, True),
"close",
"Close",
tr("关闭当前图像"), )
quit = action(
tr("&退出"),
self.close,
"quit",
"Quit",
tr("退出软件"), )
export_label_list = action(
tr("&导出标签列表"),
partial(self.exportLabelList, None),
"export_label_list",
"ExportLabel",
tr("将标签列表导出成标签配置文件"), )
import_label_list = action(
tr("&载入标签列表"),
partial(self.importLabelList, None),
"import_label_list",
"ImportLabel",
tr("从标签配置文件载入标签列表"), )
clear_label_list = action(
tr("&清空标签列表"),
self.clearLabelList,
"clear_label_list",
"ClearLabel",
tr("清空所有的标签"), )
clear_recent = action(
tr("&清除近期文件记录"),
self.clearRecentFile,
"clear_recent",
"ClearRecent",
tr("清除近期标注文件记录"), )
model_widget = action(
tr("&模型选择"),
partial(self.toggleWidget, 0),
"model_widget",
"Net",
tr("隐藏/展示模型选择面板"),
checkable=True, )
data_widget = action(
tr("&数据列表"),
partial(self.toggleWidget, 1),
"data_widget",
"Data",
tr("隐藏/展示数据列表面板"),
checkable=True, )
label_widget = action(
tr("&标签列表"),
partial(self.toggleWidget, 2),
"label_widget",
"Label",
tr("隐藏/展示标签列表面板"),
checkable=True, )
segmentation_widget = action(
tr("&分割设置"),
partial(self.toggleWidget, 3),
"segmentation_widget",
"Setting",
tr("隐藏/展示分割设置面板"),
checkable=True, )
rs_widget = action(
tr("&遥感设置"),
partial(self.toggleWidget, 4),
"rs_widget",
"RemoteSensing",
tr("隐藏/展示遥感设置面板"),
checkable=True, )
mi_widget = action(
tr("&医疗设置"),
partial(self.toggleWidget, 5),
"mi_widget",
"MedicalImaging",
tr("隐藏/展示医疗设置面板"),
checkable=True, )
grid_ann_widget = action(
tr("&N2宫格标注"),
partial(self.toggleWidget, 6),
"grid_ann_widget",
"N2",
tr("隐藏/展示N^2宫格细粒度标注面板"),
checkable=True, )
video_play_widget = action(
tr("&视频播放"),
partial(self.toggleWidget, 7),
"video_play_widget",
"Video",
tr("隐藏/展示视频播放面板"),
checkable=True, )
video_anno_widget = action(
tr("&视频标注"),
partial(self.toggleWidget, 8),
"video_anno_widget",
"VideoAnno",
tr("隐藏/展示视频标注面板"),
checkable=True, )
td_widget = action(
tr("&3D显示"),
partial(self.toggleWidget, 9),
"td_widget",
"3D",
tr("隐藏/展示3D显示面板"),
checkable=True, )
quick_start = action(
tr("&快速入门"),
self.quickStart,
"quick_start",
"Use",
tr("主要功能使用介绍"), )
report_bug = action(
tr("&反馈问题"),
self.reportBug,
"report_bug",
"ReportBug",
tr("通过Github Issue反馈使用过程中遇到的问题。我们会尽快进行修复"), )
edit_shortcuts = action(
tr("&编辑快捷键"),
self.editShortcut,
"edit_shortcuts",
"Shortcut",
tr("编辑软件快捷键"), )
toggle_logging = action(
tr("&调试日志"),
self.toggleLogging,
"toggle_logging",
"Log",
tr("用于观察软件执行过程和进行debug。我们不会自动收集任何日志,可能会希望您在反馈问题时间打开此功能,帮助我们定位问题。"),
checkable=True, )
toggle_logging.setChecked(bool(self.settings.value("log", False)))
use_qt_widget = action(
tr("&使用QT文件窗口"),
self.useQtWidget,
"use_qt_widget",
"Qt",
tr("如果使用文件选择窗口时遇到问题可以选择使用Qt窗口"),
checkable=True, )
# print(
# "use_qt_widget",
# self.settings.value("use_qt_widget", type=bool),
# )
use_qt_widget.setChecked(
self.settings.value(
"use_qt_widget", False, type=bool))
self.actions = util.struct()
for name in dir():
if name not in start:
self.actions.append(eval(name))
def newWidget(text, icon, showAction):
widget = QtWidgets.QMenu(text)
widget.setIcon(util.newIcon(icon))
widget.aboutToShow.connect(showAction)
return widget
recent_files = newWidget(self.tr("近期文件"), "Data", self.updateRecentFile)
recent_params = newWidget(
self.tr("近期模型及参数"), "Net", self.updateModelMenu)
video_recent_params = newWidget(
self.tr("近期视频传播模型及参数"), "Net", self.updateVideoModelMenu)
languages = newWidget(self.tr("语言"), "Language", self.updateLanguage)
self.menus = util.struct(
recent_files=recent_files,
recent_params=recent_params,
video_recent_params=video_recent_params,
languages=languages,
fileMenu=(
open_image,
open_folder,
change_output_dir,
load_param,
clear_recent,
recent_files,
recent_params,
video_recent_params,
None,
save,
save_as,
auto_save,
None,
turn_next,
turn_prev,
close,
None,
quit, ),
labelMenu=(
export_label_list,
import_label_list,
clear_label_list, ),
functionMenu=(
largest_component,
del_active_polygon,
del_all_polygon,
None,
origional_extension,
save_pseudo,
save_grayscale,
save_cutout,
set_cutout_background,
None,
save_json,
save_coco,
None,
# test
self.show_rs_poly,
None,
self.grid_message, ),
showMenu=(
model_widget,
data_widget,
label_widget,
segmentation_widget,
rs_widget,
mi_widget,
grid_ann_widget,
video_play_widget,
video_anno_widget,
td_widget, ),
helpMenu=(
languages,
use_qt_widget,
quick_start,
report_bug,
edit_shortcuts,
toggle_logging, ),
expandMenu=(eiseg_med3D, ),
toolBar=(
finish_object,
clear,
undo,
redo,
turn_prev,
turn_next,
None,
save_pseudo,
save_grayscale,
save_cutout,
save_json,
save_coco,
origional_extension,
None,
largest_component, ), )
def menu(title, actions=None):
menu = self.menuBar().addMenu(title)
if actions:
util.addActions(menu, actions)
return menu
menu(tr("文件"), self.menus.fileMenu)
menu(tr("标注"), self.menus.labelMenu)
menu(tr("功能"), self.menus.functionMenu)
menu(tr("显示"), self.menus.showMenu)
menu(tr("帮助"), self.menus.helpMenu)
menu(tr("更多"), self.menus.expandMenu)
util.addActions(self.toolBar, self.menus.toolBar)
def __setColor(self, action, setting_name):
c = action
color = QtWidgets.QColorDialog.getColor(
QtGui.QColor(*c),
self,
options=QtWidgets.QColorDialog.ShowAlphaChannel, )
action = color.getRgb()
self.settings.setValue(setting_name, [int(c) for c in action])
return action
def on_speed(self, sender):
text = self.speedComboBox.currentText()
self.ratio = int(20 * float(text[4:-1]))
if self.timer.isActive():
self.timer.stop()
self.timer.start(1000 // self.ratio)
def setCutoutBackground(self):
self.cutoutBackground = self.__setColor(self.cutoutBackground,
"cutout_background")
self.actions.set_cutout_background.setIcon(
util.newIcon(self.cutoutBackground))
def editShortcut(self):
self.ShortcutWidget.center()
self.ShortcutWidget.show()
# 多语言
def updateLanguage(self):
self.menus.languages.clear()
langs = os.listdir(osp.join(pjpath, "util/translate"))
langs = [n.split(".")[0] for n in langs if n.endswith("qm")]
langs.append("中文")
for lang in langs:
if lang == self.currLanguage:
continue
entry = util.newAction(
self,
lang,
partial(self.changeLanguage, lang),
None,
lang if lang != "Arabic" else "Egypt", )
self.menus.languages.addAction(entry)
def changeLanguage(self, lang):
self.settings.setValue("language", lang)
self.warn(self.tr("切换语言"), self.tr("切换语言需要重启软件才能生效"))
# 近期图像
def updateRecentFile(self):
menu = self.menus.recent_files
menu.clear()
recentFiles = self.settings.value(
"recent_files", QVariant([]), type=list)
files = [f for f in recentFiles if osp.exists(f)]
for i, f in enumerate(files):
icon = util.newIcon("File")
action = QtWidgets.QAction(icon, "&【%d】 %s" %
(i + 1, QtCore.QFileInfo(f).fileName()),
self)
action.triggered.connect(partial(self.openRecentImage, f))
menu.addAction(action)
if len(files) == 0:
menu.addAction(self.tr("无近期文件"))
self.settings.setValue("recent_files", files)
def addRecentFile(self, path):
if not osp.exists(path):
return
paths = self.settings.value("recent_files", QVariant([]), type=list)
if path not in paths:
paths.append(path)
if len(paths) > 15:
del paths[0]
self.settings.setValue("recent_files", paths)
self.updateRecentFile()
def clearRecentFile(self):
self.settings.remove("recent_files")
self.statusbar.showMessage(self.tr("已清除最近打开文件"), 10000)
# 模型加载
def updateModelMenu(self):
menu = self.menus.recent_params
menu.clear()
self.recentModels = [
m for m in self.recentModels if osp.exists(m["param_path"])
]
for idx, m in enumerate(self.recentModels):
icon = util.newIcon("Model")
action = QtWidgets.QAction(
icon,
f"{osp.basename(m['param_path'])}",
self, )
action.triggered.connect(
partial(self.setModelParam, m["param_path"]))
menu.addAction(action)
if len(self.recentModels) == 0:
menu.addAction(self.tr("无近期模型记录"))
self.settings.setValue("recent_params", self.recentModels)
def updateVideoModelMenu(self):
menu = self.menus.video_recent_params
menu.clear()
self.video_recentModels = [
m for m in self.video_recentModels if osp.exists(m["param_path"])
]
for idx, m in enumerate(self.video_recentModels):
icon = util.newIcon("Model")
action = QtWidgets.QAction(
icon,
f"{osp.basename(m['param_path'])}",
self, )
action.triggered.connect(
partial(self.setVideoModelParam, m["param_path"]))
menu.addAction(action)
if len(self.video_recentModels) == 0:
menu.addAction(self.tr("无近期视频传播模型记录"))
self.settings.setValue("video_recent_params", self.video_recentModels)
def setModelParam(self, paramPath):
res = self.changeParam(paramPath)
if res:
return True
return False
def setVideoModelParam(self, paramPath):
res = self.changePropgationParam(paramPath)
if res:
return True
return False
def changeParam(self, param_path: str=None):
if not param_path:
filters = self.tr("Paddle静态模型权重文件(*.pdiparams)")
start_path = ("." if len(self.recentModels) == 0 else
osp.dirname(self.recentModels[-1]["param_path"]))
if self.settings.value("use_qt_widget", False, type=bool):
options = QtWidgets.QFileDialog.DontUseNativeDialog
else:
options = QtWidgets.QFileDialog.ReadOnly
param_path, _ = QtWidgets.QFileDialog.getOpenFileName(
self,
self.tr("选择传播模型参数") + " - " + __APPNAME__,
start_path,
filters,
options=options, )
# QtWidgets.QFileDialog.DontUseNativeDialog
if not param_path:
return False
# 中文路径打不开
if check_cn(param_path):
self.warn(self.tr("参数路径存在中文"), self.tr("请修改参数路径为非中文路径!"))
return False
success, res = self.controller.setModel(param_path)
if success:
model_dict = {"param_path": param_path}
if model_dict not in self.recentModels:
self.recentModels.insert(0, model_dict)
if len(self.recentModels) > 10:
del self.recentModels[-1]
else: # 如果存在移动位置,确保加载最近模型的正确
self.recentModels.remove(model_dict)
self.recentModels.insert(0, model_dict)
self.settings.setValue("recent_models", self.recentModels)
self.statusbar.showMessage(
osp.basename(param_path) + self.tr(" 模型加载成功"), 10000)
return True
else:
self.warnException(res)
return False
def changePropgationParam(self, param_path: str=None):
if not param_path:
filters = self.tr("Paddle静态模型权重文件(*.pdiparams)")
start_path = (
".") if len(self.video_recentModels) == 0 else osp.dirname(
self.video_recentModels[-1]["param_path"])
if self.settings.value("use_qt_widget", False, type=bool):
options = QtWidgets.QFileDialog.DontUseNativeDialog
else:
options = QtWidgets.QFileDialog.ReadOnly
param_path, _ = QtWidgets.QFileDialog.getOpenFileName(
self,
self.tr("选择模型参数") + " - " + __APPNAME__,
start_path,
filters,
options=options, )
# QtWidgets.QFileDialog.DontUseNativeDialog
if not param_path:
return False
# 中文路径打不开
if check_cn(param_path):
self.warn(self.tr("参数路径存在中文"), self.tr("请修改参数路径为非中文路径!"))
return False
success, res = self.video.set_model(param_path)
if success:
model_dict = {"param_path": param_path}
if model_dict not in self.video_recentModels:
self.video_recentModels.insert(0, model_dict)
if len(self.recentModels) > 10:
del self.recentModels[-1]
else: # 如果存在移动位置,确保加载最近模型的正确
self.video_recentModels.remove(model_dict)
self.video_recentModels.insert(0, model_dict)
self.settings.setValue("video_recent_models",
self.video_recentModels)
self.statusbar.showMessage(
osp.basename(param_path) + self.tr("视频传播模型加载成功"), 10000)
return True
else:
self.warnException(res)
return False
def chooseMode(self):
self.predictor_params["predictor_params"][
"with_mask"] = self.cheWithMask.isChecked()
self.controller.reset_predictor(predictor_params=self.predictor_params)
if self.cheWithMask.isChecked():
self.statusbar.showMessage(self.tr("掩膜已启用"), 10000)
else:
self.statusbar.showMessage(self.tr("掩膜已关闭"), 10000)
def loadRecentModelParam(self):
if len(self.recentModels) == 0:
self.statusbar.showMessage(self.tr("没有最近使用模型信息,请加载模型"), 10000)
return
m = self.recentModels[0]
param_path = m["param_path"]
self.setModelParam(param_path)
def loadVideoRecentModelParam(self):
if len(self.video_recentModels) == 0:
self.statusbar.showMessage(self.tr("没有最近使用的视频传播模型信息,请加载模型"), 10000)
return
m = self.video_recentModels[0]
param_path = m["param_path"]
self.setVideoModelParam(param_path)
# 标签列表
def importLabelList(self, filePath=None):
if filePath is None:
if self.settings.value("use_qt_widget", False, type=bool):
options = QtWidgets.QFileDialog.DontUseNativeDialog
else:
options = QtWidgets.QFileDialog.ReadOnly
filters = self.tr("标签配置文件") + " (*.txt)"
filePath, _ = QtWidgets.QFileDialog.getOpenFileName(
self,
self.tr("选择标签配置文件路径") + " - " + __APPNAME__,
".",
filters,
options=options, )
filePath = normcase(filePath)
if not osp.exists(filePath):
return
self.controller.importLabel(filePath)
logger.info(f"Loaded label list: {self.controller.labelList.labelList}")
self.refreshLabelList()
def exportLabelList(self, savePath: str=None):
if len(self.controller.labelList) == 0:
self.warn(self.tr("没有需要保存的标签"), self.tr("请先添加标签之后再进行保存!"))
return
if savePath is None:
filters = self.tr("标签配置文件") + "(*.txt)"
dlg = QtWidgets.QFileDialog(
self,
self.tr("保存标签配置文件"),
".",
filters, )
dlg.setOption(QtWidgets.QFileDialog.DontConfirmOverwrite, False)
if self.settings.value("use_qt_widget", False, type=bool):
options = QtWidgets.QFileDialog.DontUseNativeDialog
else:
options = QtWidgets.QFileDialog.DontUseCustomDirectoryIcons
dlg.setDefaultSuffix("txt")
dlg.setAcceptMode(QtWidgets.QFileDialog.AcceptSave)
savePath, _ = dlg.getSaveFileName(
self,
self.tr("选择保存标签配置文件路径") + " - " + __APPNAME__,
".",
filters,
options=options, )
self.controller.exportLabel(savePath)
def addLabel(self, idx=None, txt="", c=None):
c = self.colorMap.get_color()
table = self.labelListTable
idx = table.rowCount()
table.insertRow(table.rowCount())
self.controller.addLabel(idx + 1, txt, c)
numberItem = QTableWidgetItem(str(idx + 1))
numberItem.setFlags(QtCore.Qt.ItemIsEnabled)
table.setItem(idx, 0, numberItem)
table.setItem(idx, 1, QTableWidgetItem())
colorItem = QTableWidgetItem()
colorItem.setBackground(QtGui.QColor(c[0], c[1], c[2]))
colorItem.setFlags(QtCore.Qt.ItemIsEnabled)
table.setItem(idx, 2, colorItem)
delItem = QTableWidgetItem()
delItem.setIcon(util.newIcon("Clear"))
delItem.setTextAlignment(Qt.AlignCenter)
delItem.setFlags(QtCore.Qt.ItemIsEnabled)
table.setItem(idx, 3, delItem)
self.adjustTableSize()
self.labelListClicked(self.labelListTable.rowCount() - 1, 0)
def adjustTableSize(self):
self.labelListTable.horizontalHeader().setDefaultSectionSize(25)
self.labelListTable.horizontalHeader().setSectionResizeMode(
0, QtWidgets.QHeaderView.Fixed)
self.labelListTable.horizontalHeader().setSectionResizeMode(
3, QtWidgets.QHeaderView.Fixed)
self.labelListTable.horizontalHeader().setSectionResizeMode(
2, QtWidgets.QHeaderView.Fixed)
self.labelListTable.setColumnWidth(2, 50)
def clearLabelList(self):
if len(self.controller.labelList) == 0:
return True
res = self.warn(
self.tr("清空标签列表?"),
self.tr("请确认是否要清空标签列表"),
QMessageBox.Yes | QMessageBox.Cancel, )
if res == QMessageBox.Cancel:
return False
self.controller.labelList.clear()
if self.controller:
self.controller.label_list = []
self.controller.curr_label_number = 0
self.labelListTable.clear()
self.labelListTable.setRowCount(0)
return True
def refreshLabelList(self):
table = self.labelListTable
table.clearContents()
table.setRowCount(len(self.controller.labelList))
table.setColumnCount(4)
for idx, lab in enumerate(self.controller.labelList):
numberItem = QTableWidgetItem(str(lab.idx))
numberItem.setFlags(QtCore.Qt.ItemIsEnabled)
table.setItem(idx, 0, numberItem)
table.setItem(idx, 1, QTableWidgetItem(lab.name))
c = lab.color
colorItem = QTableWidgetItem()
colorItem.setBackground(QtGui.QColor(c[0], c[1], c[2]))
colorItem.setFlags(QtCore.Qt.ItemIsEnabled)
table.setItem(idx, 2, colorItem)
delItem = QTableWidgetItem()
delItem.setIcon(util.newIcon("Clear"))
delItem.setTextAlignment(Qt.AlignCenter)
delItem.setFlags(QtCore.Qt.ItemIsEnabled)
table.setItem(idx, 3, delItem)
self.adjustTableSize()
cols = [0, 1, 3]
for idx in cols:
table.resizeColumnToContents(idx)
self.adjustTableSize()
def labelListDoubleClick(self, row, col):
if col != 2:
return
table = self.labelListTable
color = QtWidgets.QColorDialog.getColor()
if color.getRgb() == (0, 0, 0, 255):
return
table.item(row, col).setBackground(color)
self.controller.labelList[row].color = color.getRgb()[:3]
if self.controller:
self.controller.label_list = self.controller.labelList
for p in self.scene.polygon_items:
idlab = self.controller.labelList.getLabelById(p.labelIndex)
if idlab is not None:
color = idlab.color
p.setColor(color, color)
self.labelListClicked(row, 0)
@property
def currLabelIdx(self):
return self.controller.curr_label_number - 1
def labelListClicked(self, row, col):
table = self.labelListTable
if col == 3:
labelIdx = int(table.item(row, 0).text())
if self.status == self.EDITING:
if self.checkLabel(labelIdx):
self.controller.labelList.remove(labelIdx)
table.removeRow(row)
else:
self.warn(
self.tr("无法删除"),
self.tr("当前多边形中存在此标签"), QMessageBox.Yes)
elif self.status == self.ANNING:
self.warn(
self.tr("无法删除"), self.tr("交互式标注模式无法删除标签"), QMessageBox.Yes)
if col == 0 or col == 1:
for cl in range(2):
for idx in range(len(self.controller.labelList)):
table.item(idx,
cl).setBackground(QtGui.QColor(255, 255, 255))
table.item(row, cl).setBackground(QtGui.QColor(48, 140, 198))
table.item(row, 0).setSelected(True)
if self.controller:
self.controller.setCurrLabelIdx(int(table.item(row, 0).text()))
self.controller.label_list = self.controller.labelList
def labelListItemChanged(self, row, col):
self.colorMap.usedColors = self.controller.labelList.colors
try:
if col == 1:
name = self.labelListTable.item(row, col).text()
self.controller.labelList[row].name = name
except:
pass
# 多边形标注
def createPoly(self, curr_polygon, color):
if curr_polygon is None:
return
for points in curr_polygon:
if len(points) < 3:
continue
poly = PolygonAnnotation(
self.controller.labelList[self.currLabelIdx].idx,
self.controller.image.shape,
self.delPolygon,
self.setDirty,
color,
color,
self.opacity, )
poly.labelIndex = self.controller.labelList[self.currLabelIdx].idx
self.scene.addItem(poly)
self.scene.polygon_items.append(poly)
for p in points:
poly.addPointLast(QtCore.QPointF(p[0], p[1]))
self.setDirty(True)
def delActivePolygon(self):
for idx, polygon in enumerate(self.scene.polygon_items):
if polygon.hasFocus():
res = self.warn(
self.tr("确认删除?"),
self.tr("确认要删除当前选中多边形标注?"),
QMessageBox.Yes | QMessageBox.Cancel, )
if res == QMessageBox.Yes:
self.delPolygon(polygon)
def delPolygon(self, polygon):
polygon.remove()
if self.save_status["coco"]:
if polygon.coco_id:
self.coco.delAnnotation(
polygon.coco_id,
self.coco.imgNameToId[osp.basename(self.imagePath)], )
self.setDirty(True)
def delAllPolygon(self):
for p in self.scene.polygon_items[::-1]: # 删除所有多边形
self.delPolygon(p)
def delActivePoint(self):
for polygon in self.scene.polygon_items:
polygon.removeFocusPoint()
# 图片/标签 io
def getMask(self):
if not self.controller or self.controller.image is None:
return
s = self.controller.imgShape
pesudo = np.zeros([s[0], s[1]])
# 覆盖顺序,从上往下
# TODO: 是标签数值大的会覆盖小的吗?
# A: 是列表中上面的覆盖下面的,由于标签可以移动,不一定是大小按顺序覆盖
# RE: 我们做医学的时候覆盖比较多,感觉一般是数值大的标签覆盖数值小的标签。按照上面覆盖下面的话可能跟常见的情况正好是反过来的,感觉可能从下往上覆盖会比较好
len_lab = self.labelListTable.rowCount()
for i in range(len_lab - 1, -1, -1):
idx = int(self.labelListTable.item(len_lab - i - 1, 0).text())
for poly in self.scene.polygon_items:
if poly.labelIndex == idx:
pts = np.int32([np.array(poly.scnenePoints)])
cv2.fillPoly(pesudo, pts=pts, color=idx)
return pesudo
def openRecentImage(self, file_path):
self.queueEvent(partial(self.loadImage, file_path))
self.listFiles.addItems([file_path.replace("\\", "/")])
self.currIdx = self.listFiles.count() - 1
self.listFiles.setCurrentRow(self.currIdx) # 移动位置
self.imagePaths.append(file_path)
def openImage(self, filePath: str=None):
# 在triggered.connect中使用不管默认filePath为什么返回值都为False
if not isinstance(filePath, str) or filePath is False:
prompts = ["图片", "医学影像", "遥感影像", "视频"]
filters = ""
for fmts, p in zip(self.formats, prompts):
filters += f"{p} ({' '.join(['*' + f for f in fmts])}) ;; "
filters = filters[:-3]
recentPath = self.settings.value("recent_files", [])
if len(recentPath) == 0:
recentPath = "."
else:
recentPath = osp.dirname(recentPath[0])
if self.settings.value("use_qt_widget", False, type=bool):
options = QtWidgets.QFileDialog.DontUseNativeDialog
else:
options = QtWidgets.QFileDialog.ReadOnly
filePath, _ = QtWidgets.QFileDialog.getOpenFileName(
self,
self.tr("选择待标注图片") + " - " + __APPNAME__,
recentPath,
filters,
options=options, )
if len(filePath) == 0: # 用户没选就直接关闭窗口
return
if osp.splitext(filePath)[-1] in self.video_ext:
if not paddle.device.is_compiled_with_cuda(
): # TODO: 可以使用GPU却返回False
self.warn(
self.tr("请在gpu电脑上进行视频标注"),
self.tr("准备进行视频标注,由于视频标注需要一定计算,请尽量确保在gpu的电脑上进行操作!"))
filePath = normcase(filePath)
if not self.loadImage(filePath):
return False
# 3. 添加记录
self.listFiles.addItems([filePath])
self.currIdx = self.listFiles.count() - 1
self.listFiles.setCurrentRow(self.currIdx) # 移动位置
self.imagePaths.append(filePath)
return True
def openFolder(self, inputDir: str=None):
# 1. 如果没传文件夹,弹框让用户选
if not isinstance(inputDir, str):
recentPath = self.settings.value("recent_files", [])
if len(recentPath) == 0:
recentPath = "."
else:
recentPath = osp.dirname(recentPath[-1])
options = (QtWidgets.QFileDialog.ShowDirsOnly |
QtWidgets.QFileDialog.DontResolveSymlinks)
if self.settings.value("use_qt_widget", False, type=bool):
options = options | QtWidgets.QFileDialog.DontUseNativeDialog
inputDir = QtWidgets.QFileDialog.getExistingDirectory(
self,
self.tr("选择待标注图片文件夹") + " - " + __APPNAME__,
recentPath,
options, )
if not osp.exists(inputDir):
return
# 2. 关闭当前图片,清空文件列表
self.saveImage(close=True)
self.imagePaths = []
self.listFiles.clear()
# 3. 扫描文件夹下所有图片
# 3.1 获取所有文件名
imagePaths = os.listdir(inputDir)
exts = tuple(f for fmts in self.formats for f in fmts)
imagePaths = [n for n in imagePaths if n.lower().endswith(exts)]
imagePaths = [n for n in imagePaths if not n[0] == "."]
imagePaths.sort()
if len(imagePaths) == 0:
return
# 3.2 设置默认输出路径
if self.outputDir is None:
# 没设置为文件夹下的 label 文件夹
self.outputDir = osp.join(inputDir, "label")
if not osp.exists(self.outputDir):
os.makedirs(self.outputDir)
# 3.3 有重名图片,标签保留原来拓展名
names = []
for name in imagePaths:
name = osp.splitext(name)[0]
if name not in names:
names.append(name)
else:
self.toggleOrigExt(True)
break
imagePaths = [osp.join(inputDir, n) for n in imagePaths]
for p in imagePaths:
p = normcase(p)
self.imagePaths.append(p)
self.listFiles.addItem(p)
# 3.4 加载已有的标注
if self.outputDir is not None and osp.exists(self.outputDir):
self.changeOutputDir(self.outputDir)
if len(self.imagePaths) != 0:
self.currIdx = 0
self.turnImg(0)
self.inputDir = inputDir
def loadImage(self, path):
if self.controller.model is None:
self.warn("未检测到模型", "请先加载模型参数")
return
# 1. 拒绝None和不存在的路径,关闭当前图像
if not path:
return
path = normcase(path)
if not osp.exists(path):
return
self.imagePath = path
self.saveImage(True) # 关闭当前图像
self.eximgsInit() # TODO: 将grid的部分整合到saveImage里
# 2. 判断图像类型,打开
# TODO: 加用户指定类型的功能
image = None
# 直接if会报错,因为打开遥感图像后多波段不存在,现在把遥感图像的单独抽出来了
# 自然图像
if path.lower().endswith(tuple(self.formats[0])):
image = cv2.imdecode(np.fromfile(path, dtype=np.uint8), 1)
image = image[:, :, ::-1] # BGR转RGB
if self.grid_message.isChecked():
if checkOpenGrid(image, self.thumbnail_min):
if self.loadGrid(image, False):
image, _ = self.grid.getGrid(0, 0)
# 自然图像不进行缩小
else:
if self.dockWidgets["grid"][0].isVisible() is True:
self.grid = Grids(image)
self.initGrid()
image, _ = self.grid.getGrid(0, 0)
# 医学影像
if path.lower().endswith(tuple(self.formats[1])):
if not self.dockStatus[5]:
res = self.warn(
self.tr("未启用医疗组件"),
self.tr("加载医疗影像需启用医疗组件,是否立即启用?"),
QMessageBox.Yes | QMessageBox.Cancel, )
if res == QMessageBox.Cancel:
return False
self.toggleWidget(5)
if not self.dockStatus[5]:
return False
image = med.dcm_reader(path) # TODO: 添加多层支持
if image.shape[-1] != 1:
self.warn("医学影像打开错误", "暂不支持打开多层医学影像")
return False
maxValue = np.max(image) # 根据数据模态自适应窗宽窗位
minValue = np.min(image)
if minValue == 0:
ww = maxValue
wc = int(maxValue / 2)
else:
ww = maxValue + int(abs(minValue))
wc = int((minValue + maxValue) / 2)
self.sldWw.setValue(int(ww))
self.textWw.setText(str(ww))
self.sldWc.setValue(int(wc))
self.textWc.setText(str(wc))
self.controller.rawImage = self.image = image
image = med.windowlize(image, self.ww, self.wc)
# 遥感图像
if path.lower().endswith(tuple(self.formats[
2])): # imghdr.what(path) == "tiff":
if not self.dockStatus[4]:
res = self.warn(
self.tr("未打开遥感组件"),
self.tr("打开遥感图像需启用遥感组件,是否立即启用?"),
QMessageBox.Yes | QMessageBox.Cancel, )
if res == QMessageBox.Cancel:
return False
self.toggleWidget(4)
if not self.dockStatus[4]:
return False
self.raster = Raster(path)
gi = self.raster.showGeoInfo()
self.edtGeoinfo.setText(
self.tr("● 波段数:") + gi[0] + "\n" + self.tr("● 数据类型:") + gi[1] +
"\n" + self.tr("● 行数:") + gi[2] + "\n" + self.tr("● 列数:") + gi[
3] + "\n" + "● EPSG:" + gi[4])
if max(self.rsRGB) > self.raster.geoinfo.count:
self.rsRGB = [1, 1, 1]
self.raster.setBand(self.rsRGB)
if self.grid_message.isChecked():
if self.raster.checkOpenGrid(self.thumbnail_min):
if self.loadGrid(self.raster):
image, _ = self.raster.getGrid(0, 0)
else:
image, _ = self.raster.getArray()
else:
image, _ = self.raster.getArray()
else:
if self.dockWidgets["grid"][0].isVisible() is True:
self.grid = RSGrids(self.raster)
self.raster.open_grid = True
self.initGrid()
image, _ = self.raster.getGrid(0, 0)
else:
image, _ = self.raster.getArray()
self.updateBandList()
# self.updateSlideSld(True)
else:
self.edtGeoinfo.setText(self.tr("无"))
# 视频
if path.lower().endswith(tuple(self.formats[3])): # mp4
if not self.dockStatus[7]:
res = self.warn(
self.tr("未启用视频组件"),
self.tr("加载视频需启用视频组件,是否立即启用?"),
QMessageBox.Yes | QMessageBox.Cancel, )
if res == QMessageBox.Cancel:
return False
self.toggleWidget(7)
self.toggleWidget(8)
if not self.dockStatus[7]:
return False
if not self.dockStatus[8]:
return False
# self.video_masks = None
self.video_images, self.fps = self.video.set_video(path)
self.video_masks = np.zeros(
(self.video.num_frames, self.video.height, self.video.width),
dtype=np.uint8)
self.sldTime.setMaximum(self.video.num_frames - 1)
image = self.video_images[self.video.cursur]
self.sldTime.setProperty("value", 0)
# 清空3d显示
if self.TDDock.isVisible():
self.vtkWidget.init()
# TODO: 处理
# 如果没找到图片的reader
if image is None:
self.warn("打开图像失败", f"未找到{path}文件对应的读取程序")
return
self.image = image
self.controller.setImage(image)
self.updateImage(True)
# 2. 加载标签
self.loadLabel(path)
self.addRecentFile(path)
return True
def loadLabel(self, imgPath):
if imgPath == "":
return None
if self.video_images is not None:
videoName = osp.splitext(osp.basename(imgPath))[0]
maskPath = None
for path in self.labelPaths:
if osp.basename(path) == videoName:
maskPath = osp.join(path, 'mask')
if not maskPath:
return
for cursur in range(self.video.num_frames):
h, w = self.video_masks[cursur].shape
frame_mask = np.zeros([h, w])
pseudo = cv2.imread(
osp.join(maskPath, '{:05d}.png'.format(cursur)))
for lab in self.controller.labelList:
frame_mask[(pseudo == lab.color[::-1])[:, :, 0]] = lab.idx
self.video_masks[cursur] = frame_mask
return
# 1. 读取json格式标签
if self.save_status["json"]:
def getName(path):
return osp.splitext(osp.basename(path))[0]
imgName = getName(imgPath)
labelPath = None
for path in self.labelPaths:
if not path.endswith(".json"):
continue
if self.origExt:
if getName(path) == osp.basename(imgPath):
labelPath = path
break
else:
if getName(path) == imgName:
labelPath = path
break
if not labelPath:
return
labels = json.loads(open(labelPath, "r").read())
for label in labels:
color = label["color"]
labelIdx = label["labelIdx"]
points = label["points"]
poly = PolygonAnnotation(
labelIdx,
self.controller.image.shape,
self.delPolygon,
self.setDirty,
color,
color,
self.opacity, )
self.scene.addItem(poly)
self.scene.polygon_items.append(poly)
for p in points:
poly.addPointLast(QtCore.QPointF(p[0], p[1]))
# 2. 读取coco格式标签
if self.save_status["coco"]:
imgId = self.coco.imgNameToId.get(osp.basename(imgPath), None)
if imgId is None:
return
anns = self.coco.imgToAnns[imgId]
for ann in anns:
xys = ann["segmentation"][0]
points = []
for idx in range(0, len(xys), 2):
points.append([xys[idx], xys[idx + 1]])
labelIdx = ann["category_id"]
idlab = self.controller.labelList.getLabelById(labelIdx)
if idlab is not None:
color = idlab.color
poly = PolygonAnnotation(
ann["category_id"],
self.controller.image.shape,
self.delPolygon,
self.setDirty,
color,
color,
self.opacity,
ann["id"], )
self.scene.addItem(poly)
self.scene.polygon_items.append(poly)
for p in points:
poly.addPointLast(QtCore.QPointF(p[0], p[1]))
def turnImg(self, delta, list_click=False):
if (self.grid is None or self.grid.curr_idx is None) or list_click:
# 1. 检查是否有图可翻,保存标签
self.currIdx += delta
if self.currIdx >= len(self.imagePaths) or self.currIdx < 0:
self.currIdx -= delta
if delta == 1:
self.statusbar.showMessage(self.tr(f"没有后一张图片"))
else:
self.statusbar.showMessage(self.tr(f"没有前一张图片"))
self.saveImage(False)
return
else:
self.saveImage(True)
# 2. 打开新图
self.loadImage(self.imagePaths[self.currIdx])
self.listFiles.setCurrentRow(self.currIdx)
else:
self.turnGrid(delta)
self.setDirty(False)
def imageListClicked(self):
if not self.controller:
self.warn(self.tr("模型未加载"), self.tr("尚未加载模型,请先加载模型!"))
self.changeParam()
if not self.controller:
return
if self.controller.is_incomplete_mask:
self.exportLabel()
toRow = self.listFiles.currentRow()
delta = toRow - self.currIdx
self.turnImg(delta, True)
def finishObject(self):
if not self.controller or self.image is None:
return
current_mask, curr_polygon = self.controller.finishObject(
building=self.boundaryRegular.isChecked())
if curr_polygon is not None:
self.updateImage()
if current_mask is not None:
# current_mask = current_mask.astype(np.uint8) * 255
# polygon = util.get_polygon(current_mask)
color = self.controller.labelList[self.currLabelIdx].color
self.createPoly(curr_polygon, color)
# 状态改变
if self.status == self.EDITING:
self.status = self.ANNING
for p in self.scene.polygon_items:
p.setAnning(isAnning=True)
else:
self.status = self.EDITING
for p in self.scene.polygon_items:
p.setAnning(isAnning=False)
current_mask = self.getMask()
if self.video_images is not None:
if current_mask.max() != 0:
self.video_masks[self.video.cursur] = current_mask
def completeLastMask(self):
# 返回最后一个标签是否完成,false就是还有带点的
if not self.controller or self.controller.image is None:
return True
if not self.controller.is_incomplete_mask:
return True
res = self.warn(
self.tr("完成最后一个目标?"),
self.tr("是否完成最后一个目标的标注,不完成不会进行保存。"),
QMessageBox.Yes | QMessageBox.Cancel, )
if res == QMessageBox.Yes:
self.finishObject()
self.exportLabel()
self.setDirty(False)
return True
return False
def saveImage(self, close=False):
if self.controller and self.controller.image is not None:
# 1. 完成正在交互式标注的标签
self.completeLastMask()
# 2. 进行保存
if self.isDirty:
if self.actions.auto_save.isChecked():
self.exportLabel()
else:
res = self.warn(
self.tr("保存标签?"),
self.tr("标签尚未保存,是否保存标签"),
QMessageBox.Yes | QMessageBox.Cancel, )
if res == QMessageBox.Yes:
self.exportLabel()
self.setDirty(False)
if close:
# 3. 清空多边形标注,删掉图片
for p in self.scene.polygon_items[::-1]:
p.remove()
self.scene.polygon_items = []
self.controller.resetLastObject()
self.updateImage()
self.controller.image = None
if close:
self.annImage.setPixmap(QPixmap())
if self.video_images is not None and self.video_masks is not None:
self.reset_video()
def reset_video(self):
self.video_images = None
self.video_masks = None
self.timer.stop()
self.textTime.setText(str(0))
self.videoPlay.setText(self.tr("播放"))
self.videoPlay.setIcon(
QtGui.QIcon(osp.join(pjpath, "resource/Play.png")))
self.ratio = 20
self.speedComboBox.setCurrentIndex(2)
self.video.reset()
def exportLabel(self, saveAs=False, savePath=None, lab_input=None):
# 1. 需要处于标注状态
if not self.controller or self.controller.image is None:
return
# 2. 完成正在交互式标注的标签
self.completeLastMask()
# 3. 确定保存路径
# 3.1 如果参数指定了保存路径直接存到savePath
if not savePath:
if not saveAs and self.outputDir is not None:
# 3.2 指定了标签文件夹,而且不是另存为:根据标签文件夹和文件名出保存路径
name, ext = osp.splitext(osp.basename(self.imagePath))
if not self.origExt:
ext = ".png"
savePath = osp.join(
self.outputDir,
name + ext, )
if self.video_images is not None and self.video_masks is not None:
savePath = osp.join(self.outputDir, name)
os.makedirs(savePath, exist_ok=True)
else:
# 3.3 没有指定标签存到哪,或者是另存为:弹框让用户选
savePath = self.chooseSavePath()
if savePath is None or not osp.exists(osp.dirname(savePath)):
return
if savePath not in self.labelPaths:
self.labelPaths.append(savePath)
# 视频帧保存&视频保存
if self.video_masks is not None:
if osp.exists(savePath):
res = self.warn(
self.tr("文件夹已经存在"),
self.tr("该文件夹下不为空,您确定继续保存在此路径下吗?"),
QMessageBox.Yes | QMessageBox.Cancel, )
if res == QMessageBox.Cancel:
return
os.makedirs(savePath, exist_ok=True)
if osp.isdir(savePath):
mask_dir = osp.join(savePath, 'mask')
overlay_dir = osp.join(savePath, 'overlay')
os.makedirs(mask_dir, exist_ok=True)
os.makedirs(overlay_dir, exist_ok=True)
progress = QtWidgets.QProgressDialog(self)
progress.setWindowTitle("请稍等")
progress.setLabelText("正在保存...")
progress.setCancelButtonText("取消")
progress.setMinimumDuration(5)
progress.setWindowModality(Qt.WindowModal)
progress.setRange(0, self.video.num_frames)
videoname = savePath + "_overlay.mp4"
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
h, w = self.video_masks[0].shape
videoWrite = cv2.VideoWriter(videoname, fourcc, self.fps,
(w, h))
for i in range(0, self.video.num_frames):
# Save mask
mask = self.video_masks[i].astype('uint8')
pseudo = np.zeros([h, w, 3])
# mask = self.controller.result_mask
# print(pseudo.shape, mask.shape)
for lab in self.controller.labelList:
pseudo[mask == lab.idx, :] = lab.color[::-1]
cv2.imwrite(
os.path.join(mask_dir, '{:05d}.png'.format(i)), pseudo)
# Save overlay
overlay = overlay_davis(self.video_images[i],
self.video_masks[i], self.opacity,
self.controller.palette)
videoWrite.write(overlay[:, :, ::-1]) # write video
overlay = Image.fromarray(overlay)
overlay.save(
os.path.join(overlay_dir, '{:05d}.png'.format(i)))
progress.setValue(i)
if progress.wasCanceled():
# QMessageBox.warning(self, "提示", "保存失败")
break
progress.setValue(self.video.num_frames)
videoWrite.release()
self.setDirty(False)
self.statusbar.showMessage(
self.tr("视频帧成功保存至") + " " + savePath, 5000)
return
if lab_input is None:
mask_output = self.getMask()
s = self.controller.imgShape
else:
mask_output = lab_input
s = lab_input.shape
# BUG: 如果用了多边形标注从多边形生成mask
# 4.1 保存灰度图
if self.save_status["gray_scale"]:
if self.raster is not None:
# FIXME: when big map saved, self.raster is None,
# so adjust polygon can't saved in tif's mask.
pathHead, _ = osp.splitext(savePath)
# if self.rsSave.isChecked():
tifPath = pathHead + "_mask.tif"
self.raster.saveMask(mask_output, tifPath)
if self.shpSave.isChecked():
shpPath = pathHead + ".shp"
print(rs.save_shp(shpPath, tifPath))
else:
ext = osp.splitext(savePath)[1]
cv2.imencode(ext, mask_output)[1].tofile(savePath)
# self.labelPaths.append(savePath)
# 4.2 保存伪彩色
if self.save_status["pseudo_color"]:
pseudoPath, ext = osp.splitext(savePath)
pseudoPath = pseudoPath + "_pseudo" + ext
pseudo = np.zeros([s[0], s[1], 3], dtype="uint8")
# mask = self.controller.result_mask
mask = mask_output
# print(pseudo.shape, mask.shape)
for lab in self.controller.labelList:
pseudo[mask == lab.idx, :] = lab.color[::-1]
cv2.imencode(ext, pseudo)[1].tofile(pseudoPath)
# 4.3 保存前景抠图
if self.save_status["cutout"]:
mattingPath, ext = osp.splitext(savePath)
mattingPath = mattingPath + "_cutout" + ext
img = np.ones([s[0], s[1], 4], dtype="uint8") * 255
cim = cv2.resize(self.controller.image.copy(), (s[1], s[0]))
img[:, :, :3] = cim
img[mask_output == 0] = self.cutoutBackground
img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
cv2.imencode(ext, img)[1].tofile(mattingPath)
# 4.4 保存json
if self.save_status["json"]:
polygons = self.scene.polygon_items
labels = []
for polygon in polygons:
l = self.controller.labelList[polygon.labelIndex - 1]
label = {
"name": l.name,
"labelIdx": l.idx,
"color": l.color,
"points": [],
}
for p in polygon.scnenePoints:
label["points"].append(p)
labels.append(label)
if self.origExt:
jsonPath = savePath + ".json"
else:
jsonPath = osp.splitext(savePath)[0] + ".json"
open(jsonPath, "w", encoding="utf-8").write(json.dumps(labels))
self.labelPaths.append(jsonPath)
# 4.5 保存coco
if self.save_status["coco"]:
if not self.coco.hasImage(osp.basename(self.imagePath)):
imgId = self.coco.addImage(
osp.basename(self.imagePath), s[1], s[0])
else:
imgId = self.coco.imgNameToId[osp.basename(self.imagePath)]
for polygon in self.scene.polygon_items:
points = []
for p in polygon.scnenePoints:
for val in p:
points.append(val)
if not polygon.coco_id:
annId = self.coco.addAnnotation(imgId, polygon.labelIndex,
points)
polygon.coco_id = annId
else:
self.coco.updateAnnotation(polygon.coco_id, imgId, points)
for lab in self.controller.labelList:
if self.coco.hasCat(lab.idx):
self.coco.updateCategory(lab.idx, lab.name, lab.color)
else:
self.coco.addCategory(lab.idx, lab.name, lab.color)
saveDir = (self.outputDir
if self.outputDir is not None else osp.dirname(savePath))
cocoPath = osp.join(saveDir, "annotations.json")
open(
cocoPath, "w",
encoding="utf-8").write(json.dumps(self.coco.dataset))
self.setDirty(False)
self.statusbar.showMessage(self.tr("标签成功保存至") + " " + savePath, 5000)
def chooseSavePath(self):
formats = [
"*.{}".format(fmt.data().decode())
for fmt in QtGui.QImageReader.supportedImageFormats()
]
filters = "Label file (%s)" % " ".join(formats)
dlg = QtWidgets.QFileDialog(
self,
self.tr("保存标签文件路径"),
osp.dirname(self.imagePath),
filters, )
dlg.setDefaultSuffix("png")
dlg.setAcceptMode(QtWidgets.QFileDialog.AcceptSave)
dlg.setOption(QtWidgets.QFileDialog.DontConfirmOverwrite, False)
dlg.setOption(QtWidgets.QFileDialog.DontUseNativeDialog, False)
if self.video_masks is not None:
savePath = dlg.getExistingDirectory(
self,
self.tr("选择标签文件保存路径"),
osp.splitext(osp.basename(self.imagePath))[0], )
name, ext = osp.splitext(osp.basename(self.imagePath))
savePath = osp.join(savePath, name)
else:
savePath, _ = dlg.getSaveFileName(
self,
self.tr("选择标签文件保存路径"),
osp.splitext(osp.basename(self.imagePath))[0] + ".png", )
return savePath
def eximgsInit(self):
self.gridTable.setRowCount(0)
self.gridTable.clearContents()
# 清零
self.raster = None
self.grid = None
def setDirty(self, isDirty):
self.isDirty = isDirty
def changeOutputDir(self, outputDir=None):
# 1. 弹框选择标签路径
if outputDir is None:
options = (QtWidgets.QFileDialog.ShowDirsOnly |
QtWidgets.QFileDialog.DontResolveSymlinks)
if self.settings.value("use_qt_widget", False, type=bool):
options = options | QtWidgets.QFileDialog.DontUseNativeDialog
outputDir = QtWidgets.QFileDialog.getExistingDirectory(
self,
self.tr("选择标签保存路径") + " - " + __APPNAME__,
self.settings.value("output_dir", "."),
options, )
if not osp.exists(outputDir):
return False
self.settings.setValue("output_dir", outputDir)
self.outputDir = outputDir
# 2. 加载标签
# 2.1 如果保存coco格式,加载coco标签
if self.save_status["coco"]:
defaultPath = osp.join(self.outputDir, "annotations.json")
if osp.exists(defaultPath):
self.initCoco(defaultPath)
# 2.2 如果保存json格式,获取所有json文件名
if self.save_status["json"]:
labelPaths = os.listdir(outputDir)
labelPaths = [n for n in labelPaths if n.endswith(".json")]
labelPaths = [osp.join(outputDir, n) for n in labelPaths]
self.labelPaths = labelPaths
# 加载对应的标签列表
lab_auto_save = osp.join(self.outputDir, "autosave_label.txt")
if osp.exists(lab_auto_save) == False:
lab_auto_save = osp.join(self.outputDir,
"label/autosave_label.txt")
if osp.exists(lab_auto_save):
try:
self.importLabelList(lab_auto_save)
except:
pass
return True
def maskOpacityChanged(self):
self.sldOpacity.textLab.setText(str(self.opacity))
if not self.controller or self.controller.image is None:
return
for polygon in self.scene.polygon_items:
polygon.setOpacity(self.opacity)
self.updateImage()
if self.video_images is not None and self.video_masks is not None:
self.show_current_frame()
def clickRadiusChanged(self):
self.sldClickRadius.textLab.setText(str(self.clickRadius))
if not self.controller or self.controller.image is None:
return
self.updateImage()
if self.video_images is not None and self.video_masks is not None:
self.show_current_frame()
def threshChanged(self):
self.sldThresh.textLab.setText(str(self.segThresh))
if not self.controller or self.controller.image is None:
return
self.controller.prob_thresh = self.segThresh
self.updateImage()
if self.video_images is not None and self.video_masks is not None:
self.show_current_frame()
# def slideChanged(self):
# self.sldMISlide.textLab.setText(str(self.slideMi))
# if not self.controller or self.controller.image is None:
# return
# self.midx = int(self.slideMi) - 1
# self.miSlideSet()
# self.updateImage()
def undoClick(self):
if self.image is None:
return
if not self.controller:
return
self.controller.undoClick()
self.updateImage()
if not self.controller.is_incomplete_mask:
self.setDirty(False)
def clearAll(self):
if not self.controller or self.controller.image is None:
return
self.controller.resetLastObject()
self.updateImage()
self.setDirty(False)
def redoClick(self):
if self.image is None:
return
if not self.controller:
return
self.controller.redoClick()
self.updateImage()
def canvasClick(self, x, y, isLeft):
c = self.controller
if c.image is None:
return
if not c.inImage(x, y):
return
if not c.modelSet:
self.warn(self.tr("未选择模型", self.tr("尚未选择模型,请先在右上角选择模型")))
return
if self.status == self.IDILE:
return
currLabel = self.controller.curr_label_number
if not currLabel or currLabel == 0:
self.warn(self.tr("未选择当前标签"), self.tr("请先在标签列表中单击点选标签"))
return
self.controller.addClick(x, y, isLeft)
self.updateImage()
self.status = self.ANNING
def updateImage(self, reset_canvas=False):
if not self.controller:
return
image = self.controller.get_visualization(
alpha_blend=self.opacity,
click_radius=self.clickRadius, )
height, width, _ = image.shape
bytesPerLine = 3 * width
image = QImage(image.data, width, height, bytesPerLine,
QImage.Format_RGB888)
if reset_canvas:
self.resetZoom(width, height)
self.annImage.setPixmap(QPixmap(image))
def update_interact_viz(self):
height, width, channel = self.viz.shape
bytesPerLine = 3 * width
qImg = QImage(self.viz.data, width, height, bytesPerLine,
QImage.Format_RGB888)
self.annImage.setPixmap(QPixmap(qImg))
def viewZoomed(self, scale):
self.scene.scale = scale
self.scene.updatePolygonSize()
# 界面缩放重置
def resetZoom(self, width, height):
# 每次加载图像前设定下当前的显示框,解决图像缩小后不在中心的问题
self.scene.setSceneRect(0, 0, width, height)
# 缩放清除
self.canvas.scale(1 / self.canvas.zoom_all,
1 / self.canvas.zoom_all) # 重置缩放
self.canvas.zoom_all = 1
# 最佳缩放
s_eps = 0.98
scr_cont = [
(self.scrollArea.width() * s_eps) / width,
(self.scrollArea.height() * s_eps) / height,
]
if scr_cont[0] * height > self.scrollArea.height():
self.canvas.zoom_all = scr_cont[1]
else:
self.canvas.zoom_all = scr_cont[0]
self.canvas.scale(self.canvas.zoom_all, self.canvas.zoom_all)
self.scene.scale = self.canvas.zoom_all
def keyReleaseEvent(self, event):
# print(event.key(), Qt.Key_Control)
# 释放ctrl的时候刷新图像,对应自适应点大小在缩放后刷新
if not self.controller or self.controller.image is None:
return
if event.key() == Qt.Key_Control:
self.updateImage()
def queueEvent(self, function):
QtCore.QTimer.singleShot(0, function)
def toggleOrigExt(self, dst=None):
if dst:
self.origExt = dst
else:
self.origExt = not self.origExt
self.actions.origional_extension.setChecked(self.origExt)
def toggleAutoSave(self, save):
if save and not self.outputDir:
self.changeOutputDir(None)
if save and not self.outputDir:
save = False
self.actions.auto_save.setChecked(save)
self.settings.setValue("auto_save", save)
def toggleSave(self, type):
self.save_status[type] = not self.save_status[type]
if type == "coco" and self.save_status["coco"]:
self.initCoco()
if type == "coco":
self.save_status["json"] = not self.save_status["coco"]
self.actions.save_json.setChecked(self.save_status["json"])
if type == "json":
self.save_status["coco"] = not self.save_status["json"]
self.actions.save_coco.setChecked(self.save_status["coco"])
def initCoco(self, coco_path: str=None):
if not coco_path:
if not self.outputDir or not osp.exists(self.outputDir):
coco_path = None
else:
coco_path = osp.join(self.outputDir, "annotations.json")
else:
if not osp.exists(coco_path):
coco_path = None
self.coco = COCO(coco_path)
if self.clearLabelList():
self.controller.labelList = util.LabelList(self.coco.dataset[
"categories"])
self.refreshLabelList()
def toggleWidget(self, index=None, warn=True):
# TODO: 输入从数字改成名字
# 1. 改变
if isinstance(index, int):
self.dockStatus[index] = not self.dockStatus[index]
# 2. 判断widget是否可以开启
# 2.1 遥感
if self.dockStatus[4] and not (rs.check_gdal() and rs.check_rasterio()):
if warn:
self.warn(
self.tr("无法导入GDAL或rasterio"),
self.tr("使用遥感工具需要安装GDAL和rasterio!"),
QMessageBox.Yes, )
self.statusbar.showMessage(self.tr("打开遥感工具失败,请安装GDAL和rasterio"))
self.dockStatus[4] = False
# 2.2 医疗
if self.dockStatus[5] and not med.has_sitk():
if warn:
self.warn(
self.tr("无法导入SimpleITK"),
self.tr("使用医疗工具需要安装SimpleITK!"),
QMessageBox.Yes, )
self.statusbar.showMessage(self.tr("打开医疗工具失败,请安装SimpleITK"))
self.dockStatus[5] = False
# 2.3 3D显示
if self.dockStatus[9] and not self.vtkWidget.convert_vtk():
if warn:
self.warn(
self.tr("无法导入VTK"),
self.tr("使用3D显示工具需要安装VTK!"),
QMessageBox.Yes, )
self.statusbar.showMessage(self.tr("打开3D显示工具失败,请安装VTK"))
self.dockStatus[9] = False
widgets = list(self.dockWidgets.values())
for idx, s in enumerate(self.dockStatus):
self.menus.showMenu[idx].setChecked(s)
if s:
for w in widgets[idx]:
w.show()
else:
for w in widgets[idx]:
w.hide()
self.settings.setValue("dock_status", self.dockStatus)
# self.display_dockwidget[index] = bool(self.display_dockwidget[index] - 1)
# self.toggleDockWidgets()
self.saveLayout()
# def toggleDockWidgets(self, is_init=False):
# if is_init == True:
# if self.dockStatus != []:
# if len(self.dockStatus) != len(self.menus.showMenu):
# self.settings.remove("dock_status")
# else:
# self.display_dockwidget = [strtobool(w) for w in self.dockStatus]
# for i in range(len(self.menus.showMenu)):
# self.menus.showMenu[i].setChecked(bool(self.display_dockwidget[i]))
# else:
# self.settings.setValue("dock_status", self.display_dockwidget)
# for t, w in zip(self.display_dockwidget, self.dockWidgets.values()):
# if t == True:
# w.show()
# else:
# w.hide()
def rsBandSet(self, idx):
if self.raster is None:
return
for i in range(len(self.bandCombos)):
self.rsRGB[i] = self.bandCombos[i].currentIndex() + 1 # 从1开始
self.raster.setBand(self.rsRGB)
if self.grid is not None:
if isinstance(self.grid.curr_idx, (list, tuple)):
row, col = self.grid.curr_idx
image, _ = self.raster.getGrid(row, col)
else:
image, _ = self.raster.getArray()
else:
image, _ = self.raster.getArray()
self.image = image
self.controller.image = image
self.updateImage()
# def miSlideSet(self):
# image = rs.slice_img(self.controller.rawImage, self.midx)
# self.test_show(image)
# def changeWorkerShow(self, index):
# self.display_dockwidget[index] = bool(self.display_dockwidget[index] - 1)
# self.toggleDockWidgets()
def updateBandList(self, clean=False):
if clean:
for i in range(len(self.bandCombos)):
try: # 避免打开jpg后再打开tif报错
self.bandCombos[i].currentIndexChanged.disconnect()
except TypeError:
pass
self.bandCombos[i].clear()
self.bandCombos[i].addItems(["band_1"])
return
bands = self.raster.geoinfo.count
for i in range(len(self.bandCombos)):
try: # 避免打开jpg后再打开tif报错
self.bandCombos[i].currentIndexChanged.disconnect()
except TypeError:
pass
self.bandCombos[i].clear()
self.bandCombos[i].addItems(
[("band_" + str(j + 1)) for j in range(bands)])
try:
self.bandCombos[i].setCurrentIndex(self.rsRGB[i] - 1)
except IndexError:
pass
for bandCombo in self.bandCombos:
bandCombo.currentIndexChanged.connect(self.rsBandSet) # 设置波段
# def updateSlideSld(self, clean=False):
# if clean:
# self.sldMISlide.setMaximum(1)
# return
# C = self.controller.rawImage.shape[-1] if len(self.controller.rawImage.shape) == 3 else 1
# self.sldMISlide.setMaximum(C)
def toggleLargestCC(self, on):
try:
self.controller.filterLargestCC(on)
except:
pass
# 宫格标注
def initGrid(self):
self.delAllPolygon()
grid_row_count, grid_col_count = self.grid.createGrids()
self.gridTable.setRowCount(grid_row_count)
self.gridTable.setColumnCount(grid_col_count)
for r in range(grid_row_count):
for c in range(grid_col_count):
self.gridTable.setItem(r, c, QtWidgets.QTableWidgetItem())
self.gridTable.item(r, c).setBackground(self.GRID_COLOR["idle"])
self.gridTable.item(r, c).setFlags(
Qt.ItemIsSelectable) # 无法高亮选择
# 初始显示第一个
self.grid.curr_idx = (0, 0)
self.gridTable.item(0, 0).setBackground(self.GRID_COLOR["overlying"])
# 事件注册
self.gridTable.cellClicked.connect(self.changeGrid)
# load polygon
if self.outputDir is not None:
name = osp.splitext(osp.basename(self.imagePath))[0]
json_path = osp.join(self.outputDir, name + "_grid_saved.json")
if osp.exists(json_path):
self.grid.json_labels = json.loads(open(json_path, "r").read())
# load label
for jlab in self.grid.json_labels:
is_add = True
for label in self.controller.labelList.labelList:
if jlab["labelIdx"] == label.idx and jlab[
"name"] == label.name:
is_add = False
break
if is_add is True:
self.addLabel(jlab["labelIdx"], jlab["name"], jlab["color"])
self.changeGrid(0, 0)
# load mask
for jlab in self.grid.json_labels:
pts = np.int32([np.array(jlab["points"])])
cv2.fillPoly(
self.grid.mask_grids[jlab["row"]][jlab["col"]],
pts=pts,
color=jlab["labelIdx"])
def changeGrid(self, row, col):
def find_in_json(r, c, json_labels):
idxs = []
for idx, json_label in enumerate(json_labels):
if json_label["row"] == r and json_label["col"] == c:
idxs.append(idx)
return idxs
# 清除未保存的切换
self.finishObject()
# TODO: 这块应该通过dirty判断?
if self.grid.curr_idx is not None:
self.saveGrid() # 切换时自动保存上一块
last_r, last_c = self.grid.curr_idx
if self.grid.mask_grids[last_r][last_c] is None:
self.gridTable.item(
last_r, last_c).setBackground(self.GRID_COLOR["idle"])
else:
self.gridTable.item(
last_r, last_c).setBackground(self.GRID_COLOR["finised"])
self.delAllPolygon()
image, _ = self.grid.getGrid(row, col)
self.controller.setImage(image)
self.grid.curr_idx = (row, col)
idxs = find_in_json(row, col, self.grid.json_labels)
if len(idxs) != 0:
# 加载之前的标注
self.gridTable.item(row,
col).setBackground(self.GRID_COLOR["overlying"])
for idx in idxs:
label = self.grid.json_labels[idx]
color = label["color"]
labelIdx = label["labelIdx"]
points = label["points"]
poly = PolygonAnnotation(
labelIdx,
self.controller.image.shape,
self.delPolygon,
self.setDirty,
color,
color,
self.opacity, )
self.scene.addItem(poly)
self.scene.polygon_items.append(poly)
for p in points:
poly.addPointLast(QtCore.QPointF(p[0], p[1]))
[self.grid.json_labels.remove(celement) for \
celement in [self.grid.json_labels[i] for i in idxs]]
else:
self.gridTable.item(row,
col).setBackground(self.GRID_COLOR["current"])
# 刷新
self.updateImage(True)
def saveGrid(self):
row, col = self.grid.curr_idx
if self.grid.curr_idx is None:
return
self.gridTable.item(row,
col).setBackground(self.GRID_COLOR["overlying"])
# if len(np.unique(self.grid.mask_grids[row][col])) == 1:
self.grid.mask_grids[row][col] = np.array(self.getMask())
# save grid label to load
polygons = self.scene.polygon_items
for polygon in polygons:
l = self.controller.labelList[polygon.labelIndex - 1]
label = {
"row": row,
"col": col,
"name": l.name,
"labelIdx": l.idx,
"color": l.color,
"points": [],
}
for p in polygon.scnenePoints:
label["points"].append(p)
self.grid.json_labels.append(label)
# save every blocks or not
if self.cheSaveEvery.isChecked():
_, fullflname = osp.split(self.listFiles.currentItem().text())
fname, _ = os.path.splitext(fullflname)
if self.outputDir is None:
if self.changeOutputDir() is False:
self.cheSaveEvery.setChecked(False)
return
save_ima_path = osp.join(
self.outputDir,
(fname + "_data_" + str(row) + "_" + str(col) + ".tif"))
save_lab_path = osp.join(
self.outputDir,
(fname + "_mask_" + str(row) + "_" + str(col) + ".tif"))
im, tf = self.raster.getGrid(row, col)
h, w = im.shape[:2]
geoinfo = edict()
geoinfo.xsize = w
geoinfo.ysize = h
geoinfo.dtype = self.raster.geoinfo.dtype
geoinfo.crs = self.raster.geoinfo.crs
geoinfo.geotf = tf
self.raster.saveMask(self.grid.mask_grids[row][col], save_lab_path,
geoinfo) # 保存mask
self.raster.saveMask(im, save_ima_path, geoinfo, 3) # 保存图像
def turnGrid(self, delta):
# 切换下一个宫格
r, c = self.grid.curr_idx if self.grid.curr_idx is not None else (0, -1)
c += delta
if c >= self.grid.grid_count[1]:
c = 0
r += 1
if r >= self.grid.grid_count[0]:
r = 0
if c < 0:
c = self.grid.grid_count[1] - 1
r -= 1
if r < 0:
r = self.grid.grid_count[0] - 1
self.changeGrid(r, c)
def closeGrid(self):
self.grid = None
self.gridTable.setRowCount(0)
self.gridTable.clearContents()
def saveGridLabel(self):
if self.grid is None:
return
if self.outputDir is not None:
name, ext = osp.splitext(osp.basename(self.imagePath))
if not self.origExt:
ext = ".png"
save_path = osp.join(self.outputDir, name + ext)
else:
save_path = self.chooseSavePath()
if save_path == "":
return
try:
self.finishObject()
self.saveGrid() # 先保存当前
except:
pass
self.delAllPolygon() # 清理
mask = self.grid.splicingList(save_path)
json_path = save_path.replace(".png", "_grid_saved.json")
open(
json_path, "w",
encoding="utf-8").write(json.dumps(self.grid.json_labels))
if self.grid.__class__.__name__ == "RSGrids":
self.image, geo_tf = self.raster.getArray()
if geo_tf is None:
self.statusbar.showMessage(self.tr("图像过大,已显示缩略图"))
else:
self.image = self.grid.detimg
self.controller.image = self.image
self.controller._result_mask = mask
self.exportLabel(savePath=save_path, lab_input=mask)
# -- RS Show polygon demo --
if self.show_rs_poly.isChecked():
h, w = self.image.shape[:2]
th_mask = cv2.resize(
mask, dsize=(w, h), interpolation=cv2.INTER_NEAREST)
indexs = np.unique(th_mask)[1:]
for i in indexs:
i_mask = np.zeros_like(th_mask, dtype="uint8")
i_mask[th_mask == i] = 255
curr_polygon = util.get_polygon(i_mask)
color = self.controller.labelList[i - 1].color
self.createPoly(curr_polygon, color)
for p in self.scene.polygon_items:
p.setAnning(isAnning=False)
# -- RS Show polygon demo --
# 刷新
grid_row_count = self.gridTable.rowCount()
grid_col_count = self.gridTable.colorCount()
for r in range(grid_row_count):
for c in range(grid_col_count):
try:
self.gridTable.item(
r, c).setBackground(self.GRID_COLOR["idle"])
except:
pass
self.raster = None
self.closeGrid()
self.updateBandList(True)
self.controller.setImage(self.image)
self.updateImage(True)
self.setDirty(False)
@property
def opacity(self):
return self.sldOpacity.value() / 100
@property
def clickRadius(self):
return self.sldClickRadius.value()
@property
def segThresh(self):
return self.sldThresh.value() / 100
# @property
# def slideMi(self):
# return self.sldMISlide.value()
def warnException(self, e):
e = str(e)
title = e.split("。")[0]
self.warn(title, e)
def warn(self, title, text, buttons=QMessageBox.Yes):
msg = QMessageBox()
# msg.setIcon(QMessageBox.Warning)
msg.setWindowTitle(title)
msg.setText(text)
msg.setStandardButtons(buttons)
return msg.exec_()
@property
def status(self):
# TODO: 图片,模型
if not self.controller:
return self.IDILE
c = self.controller
if c.model is None or c.image is None:
return self.IDILE
if self._anning:
return self.ANNING
return self.EDITING
@status.setter
def status(self, status):
if status not in [self.ANNING, self.EDITING]:
return
if status == self.ANNING:
self._anning = True
else:
self._anning = False
def loadGrid(self, img, is_rs=True):
res = self.warn(self.tr("图像过大"), self.tr("图像过大,将启用宫格功能!"), \
buttons=QMessageBox.Yes | QMessageBox.No)
if res == QMessageBox.Yes:
# 打开宫格功能
if self.dockWidgets["grid"][0].isVisible() is False:
# TODO: 改成self.dockStatus
self.menus.showMenu[-1].setChecked(True)
# self.display_dockwidget[-1] = True
self.dockWidgets["grid"][0].show()
self.grid = RSGrids(img) if is_rs else Grids(img)
self.initGrid()
return True
return False
# 界面布局
def loadLayout(self):
self.restoreState(self.layoutStatus)
# TODO: 这里检查环境,判断是不是开医疗和遥感widget
def saveLayout(self):
# 保存界面
self.settings.setValue("layout_status", QByteArray(self.saveState()))
self.settings.setValue(
"save_status",
[(k, self.save_status[k]) for k in self.save_status.keys()])
# # 如果设置了保存路径,把标签也保存下
# if self.outputDir is not None and len(self.controller.labelList) != 0:
# self.exportLabelList(osp.join(self.outputDir, "autosave_label.txt"))
def closeEvent(self, event):
self.saveImage()
self.saveLayout()
QCoreApplication.quit()
# sys.exit(0)
def reportBug(self):
webbrowser.open("https://github.com/PaddlePaddle/PaddleSeg/issues")
def enterEISegMed3D(self):
webbrowser.open(
"https://github.com/PaddlePaddle/PaddleSeg/tree/develop/EISeg/med3d")
def quickStart(self):
# self.saveImage(True)
# self.canvas.setStyleSheet(self.note_style)
webbrowser.open(
"https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/EISeg")
def toggleLogging(self, s):
if s:
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.CRITICAL)
self.settings.setValue("log", s)
def toBeImplemented(self):
self.statusbar.showMessage(self.tr("功能尚在开发"))
# 医疗
def wwChanged(self):
if not self.controller or self.image is None:
return
try: # 那种jpg什么格式的医疗图像调整窗宽等会造成崩溃
self.textWw.selectAll()
self.controller.image = med.windowlize(self.controller.rawImage,
self.ww, self.wc)
self.updateImage()
except:
pass
def wcChanged(self):
if not self.controller or self.image is None:
return
try:
self.textWc.selectAll()
self.controller.image = med.windowlize(self.controller.rawImage,
self.ww, self.wc)
self.updateImage()
except:
pass
@property
def ww(self):
return int(self.textWw.text())
@property
def wc(self):
return int(self.textWc.text())
def twwChanged(self):
if self.ww > self.sldWw.maximum():
self.textWw.setText(str(self.sldWw.maximum()))
if self.ww < self.sldWw.minimum():
self.textWw.setText(str(self.sldWw.minimum()))
self.sldWw.setProperty("value", self.ww)
self.wwChanged()
def swwChanged(self):
self.textWw.setText(str(self.sldWw.value()))
self.wwChanged()
def twcChanged(self):
if self.wc > self.sldWc.maximum():
self.textWc.setText(str(self.sldWc.maximum()))
if self.wc < self.sldWc.minimum():
self.textWc.setText(str(self.sldWc.minimum()))
self.sldWc.setProperty("value", self.wc)
self.wcChanged()
def swcChanged(self):
self.textWc.setText(str(self.sldWc.value()))
self.wcChanged()
# 视频
def tframeChanged(self):
if self.video_images is None:
return
if self.video.cursur > self.sldTime.maximum():
self.textTime.setText(str(self.sldTime.maximum()))
if self.video.cursur < self.sldTime.minimum():
self.textTime.setText(str(self.sldTime.minimum()))
self.sldTime.setProperty("value", int(self.textTime.text()))
def sframeChanged(self):
if self.video_images is None:
return
self.textTime.setText(str(self.sldTime.value()))
self.video.cursur = int(self.textTime.text())
self.controller.setImage(self.video_images[self.video.cursur])
self.delAllPolygon()
self.show_current_frame()
# print('current_frame:',self.video.cursur)
def turnPreFrame(self):
if self.video_images is None:
return
self.video.cursur -= 1
if self.video.cursur < 0:
self.video.cursur = self.video.num_frames - 1
self.sldTime.setProperty("value", self.video.cursur)
def turnNextFrame(self):
if self.video_images is None:
return
self.video.cursur += 1
if self.video.cursur > self.video.num_frames - 1:
self.video.cursur = 0
self.sldTime.setProperty("value", self.video.cursur)
def show_current_frame(self):
self.viz = overlay_davis(self.video_images[self.video.cursur],
self.video_masks[self.video.cursur],
self.opacity, self.controller.palette)
self.update_interact_viz()
self.sldTime.setProperty("value", self.video.cursur)
def brushChanged(self):
self.textBrush.setText(str(self.sldBrush.value()))
def on_time(self):
self.video.cursur += 1
if self.video.cursur > self.video.num_frames - 1:
self.video.cursur = 0
self.sldTime.setProperty("value", self.video.cursur)
def on_play(self):
if self.video_images is None:
self.warn(self.tr("图片格式无法播放"), self.tr("请先加载视频"))
return
if self.timer.isActive():
self.timer.stop()
self.videoPlay.setText(self.tr("播放"))
self.videoPlay.setIcon(
QtGui.QIcon(osp.join(pjpath, "resource/Play.png")))
else:
# self.delAllPolygon()
self.timer.start(1000 // self.ratio)
self.videoPlay.setText(self.tr("暂停"))
self.videoPlay.setIcon(
QtGui.QIcon(osp.join(pjpath, "resource/Stop.png")))
def getVideoMask(self):
if self.video_masks is not None:
return self.video_masks[self.video.cursur]
else:
return None
def on_propgation(self):
self.finishObject()
if self.video_images is None:
self.warn(self.tr("未加载视频"), self.tr("请先在加载图像按钮中加载视频"))
return
if self.video.prop_net_segm is None:
self.warn(self.tr("传播模型未加载"), self.tr("尚未加载视频传播模型,请先加载模型!"))
return
if self.video.fuse_net is None:
self.warn(self.tr("融合模型未加载"), self.tr("尚未加载视频融合模型,请先加载模型!"))
return
current_mask = self.getMask()
if current_mask is None:
self.warn(self.tr("未提供传播参考帧"), self.tr("请先在标注传播参考帧再进行传播"))
return
if current_mask.max() == 0:
current_mask = self.video_masks[self.video.cursur]
# self.warn(self.tr("未新增标注"), self.tr("请先添加新标注再进行传播"))
# return
print('-------------start propgation----------------')
self.statusbar.showMessage(self.tr("开始传播"))
# set object
self.video.set_objects(int(max(self.video.k, current_mask.max())))
self.video.set_images(self.video_images)
one_hot_mask = F.one_hot(
paddle.to_tensor(current_mask).astype('int32'),
int(self.video.k + 1))
self.one_hot_mask = one_hot_mask.transpose([2, 0, 1]).unsqueeze(1)
start = time.time()
self.video_masks = self.video.interact(
self.one_hot_mask, self.video.cursur, self.progress_total_cb,
self.progress_step_cb)
end = time.time()
print("propagation time cost", end - start)
self.statusbar.showMessage(self.tr("传播完成!"), 5000)
# 传播进度条重置
self.proPropagete.setValue(0)
self.proPropagete.setFormat('0%')
self.delAllPolygon()
self.show_current_frame()
# 3d显示
color_map = []
for lab in self.controller.labelList:
color_map.append(lab.color)
if self.TDDock.isVisible():
self.vtkWidget.show_array(
np.uint8(self.video_masks), (1., 1., 1.), color_map)
def progress_step_cb(self):
self.progress_num += 1
ratio = self.progress_num / self.progress_max
self.proPropagete.setValue(int(ratio * 100))
self.proPropagete.setFormat('%2.1f%%' % (ratio * 100))
QApplication.processEvents()
def progress_total_cb(self, total):
self.progress_max = total
self.progress_num = -1
self.progress_step_cb()
def useQtWidget(self, s):
self.settings.setValue("use_qt_widget", s)
def checkLabel(self, labelIndex):
for p in self.scene.polygon_items:
if p.labelIndex == labelIndex:
return False
return True
53,119,181
245,128,6
67,159,36
204,43,41
145,104,190
135,86,75
219,120,195
127,127,127
187,189,18
72,190,207
178,199,233
248,187,118
160,222,135
247,153,150
195,176,214
192,156,148
241,183,211
199,199,199
218,219,139
166,218,229
\ No newline at end of file
shortcut:
about: Q
auto_save: X
change_output_dir: Shift+Z
clear: Ctrl+Shift+Z
clear_label: ''
clear_recent: ''
close: Ctrl+W
data_worker: ''
del_active_polygon: Backspace
edit_shortcuts: E
finish_object: Space
grid_ann: ''
label_worker: ''
largest_component: ''
load_label: ''
load_param: Ctrl+M
medical_worker: ''
model_worker: ''
open_folder: Shift+A
open_image: Ctrl+A
origional_extension: ''
quick_start: ''
quit: ''
redo: Ctrl+Y
remote_worker: ''
save: ''
save_as: ''
save_coco: ''
save_json: ''
save_label: ''
save_pseudo: ''
set_worker: ''
turn_next: F
turn_prev: S
undo: Ctrl+Z
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import time
import json
import logging
import cv2
import numpy as np
from skimage.measure import label
import paddle
from eiseg import logger
from inference import clicker
from inference.predictor import get_predictor
import util
from util.vis import draw_with_blend_and_clicks
from models import EISegModel
from util import LabelList
class InteractiveController:
def __init__(
self,
predictor_params: dict=None,
prob_thresh: float=0.5, ):
"""初始化控制器.
Parameters
----------
predictor_params : dict
推理器配置
prob_thresh : float
区分前景和背景结果的阈值
"""
self.predictor_params = predictor_params
self.prob_thresh = prob_thresh
self.model = None
self.image = None
self.rawImage = None
self.predictor = None
self.clicker = clicker.Clicker()
self.states = []
self.probs_history = []
self.polygons = []
# 用于redo
self.undo_states = []
self.undo_probs_history = []
self.curr_label_number = 0
self._result_mask = None
self.labelList = LabelList()
self.lccFilter = False
self.log = logging.getLogger(__name__)
def filterLargestCC(self, do_filter: bool):
"""设置是否只保留推理结果中的最大联通块
Parameters
----------
do_filter : bool
是否只保存推理结果中的最大联通块
"""
if not isinstance(do_filter, bool):
return
self.lccFilter = do_filter
def setModel(self, param_path=None, use_gpu=None):
"""设置推理其模型.
Parameters
----------
params_path : str
模型路径
use_gpu : bool
None:检测,根据paddle版本判断
bool:按照指定是否开启GPU
Returns
-------
bool, str
是否成功设置模型, 失败原因
"""
if param_path is not None:
model_path = param_path.replace(".pdiparams", ".pdmodel")
if not osp.exists(model_path):
raise Exception(f"未在 {model_path} 找到模型文件")
if use_gpu is None:
if paddle.device.is_compiled_with_cuda(
): # TODO: 可以使用GPU却返回False
use_gpu = True
else:
use_gpu = False
logger.info(f"User paddle compiled with gpu: use_gpu {use_gpu}")
tic = time.time()
try:
self.model = EISegModel(model_path, param_path, use_gpu)
self.reset_predictor() # 即刻生效
except KeyError as e:
return False, str(e)
logger.info(f"Load model {model_path} took {time.time() - tic}")
return True, "模型设置成功"
def setImage(self, image: np.array):
"""设置当前标注的图片
Parameters
----------
image : np.array
当前标注的图片
"""
if self.model is not None:
self.image = image
self._result_mask = np.zeros(image.shape[:2], dtype=np.uint8)
self.resetLastObject()
# 标签操作
def setLabelList(self, labelList: json):
"""设置标签列表,会覆盖已有的标签列表
Parameters
----------
labelList : json
标签列表格式为
{
{
"idx" : int (like 0 or 1 or 2)
"name" : str (like "car" or "airplan")
"color" : list (like [255, 0, 0])
},
...
}
Returns
-------
type
Description of returned object.
"""
self.labelList.clear()
labels = json.loads(labelList)
for lab in labels:
self.labelList.add(lab["id"], lab["name"], lab["color"])
def addLabel(self, id: int, name: str, color: list):
self.labelList.add(id, name, color)
def delLabel(self, id: int):
self.labelList.remove(id)
def clearLabel(self):
self.labelList.clear()
def importLabel(self, path):
self.labelList.importLabel(path)
def exportLabel(self, path):
self.labelList.exportLabel(path)
# 点击操作
def addClick(self, x: int, y: int, is_positive: bool):
"""添加一个点并运行推理,保存历史用于undo
Parameters
----------
x : int
点击的横坐标
y : int
点击的纵坐标
is_positive : bool
是否点的是正点
Returns
-------
bool, str
点击是否添加成功, 失败原因
"""
# 1. 确定可以点
if not self.inImage(x, y):
return False, "点击越界"
if not self.modelSet:
return False, "未加载模型"
if not self.imageSet:
return False, "图像未设置"
if len(self.states) == 0: # 保存一个空状态
self.states.append({
"clicker": self.clicker.get_state(),
"predictor": self.predictor.get_states(),
})
# 2. 添加点击,跑推理
click = clicker.Click(is_positive=is_positive, coords=(y, x))
self.clicker.add_click(click)
pred = self.predictor.get_prediction(self.clicker)
# 3. 保存状态
self.states.append({
"clicker": self.clicker.get_state(),
"predictor": self.predictor.get_states(),
})
if self.probs_history:
self.probs_history.append((self.probs_history[-1][1], pred))
else:
self.probs_history.append((np.zeros_like(pred), pred))
# 点击之后就不能接着之前的历史redo了
self.undo_states = []
self.undo_probs_history = []
return True, "点击添加成功"
def undoClick(self):
"""
undo一步点击
"""
if len(self.states) <= 1: # == 1就只剩下一个空状态了,不用再退
return
self.undo_states.append(self.states.pop())
self.clicker.set_state(self.states[-1]["clicker"])
self.predictor.set_states(self.states[-1]["predictor"])
self.undo_probs_history.append(self.probs_history.pop())
if not self.probs_history:
self.reset_init_mask()
def redoClick(self):
"""
redo一步点击
"""
if len(self.undo_states) == 0: # 如果还没撤销过
return
if len(self.undo_probs_history) >= 1:
next_state = self.undo_states.pop()
self.states.append(next_state)
self.clicker.set_state(next_state["clicker"])
self.predictor.set_states(next_state["predictor"])
self.probs_history.append(self.undo_probs_history.pop())
def finishObject(self, building=False):
"""
结束当前物体标注,准备标下一个
"""
object_prob = self.current_object_prob
if object_prob is None:
return None, None
object_mask = object_prob > self.prob_thresh
if self.lccFilter:
object_mask = self.getLargestCC(object_mask)
polygon = util.get_polygon(
(object_mask.astype(np.uint8) * 255),
img_size=object_mask.shape,
building=building)
if polygon is not None:
self._result_mask[object_mask] = self.curr_label_number
self.resetLastObject()
self.polygons.append([self.curr_label_number, polygon])
return object_mask, polygon
# 多边形
def getPolygon(self):
return self.polygon
def setPolygon(self, polygon):
self.polygon = polygon
# mask
def getMask(self):
s = self.imgShape
img = np.zeros([s[0], s[1]])
for poly in self.polygons:
pts = np.int32([np.array(poly[1])])
cv2.fillPoly(img, pts=pts, color=poly[0])
return img
def setCurrLabelIdx(self, number):
if not isinstance(number, int):
return False
self.curr_label_number = number
def resetLastObject(self, update_image=True):
"""
重置控制器状态
Parameters
update_image(bool): 是否更新图像
"""
self.states = []
self.probs_history = []
self.undo_states = []
self.undo_probs_history = []
# self.current_object_prob = None
self.clicker.reset_clicks()
self.reset_predictor()
self.reset_init_mask()
def reset_predictor(self, predictor_params=None):
"""
重置推理器,可以换推理配置
Parameters
predictor_params(dict): 推理配置
"""
if predictor_params is not None:
self.predictor_params = predictor_params
if self.model.model:
self.predictor = get_predictor(self.model.model,
**self.predictor_params)
if self.image is not None:
self.predictor.set_input_image(self.image)
def reset_init_mask(self):
self.clicker.click_indx_offset = 0
def getLargestCC(self, mask):
mask = label(mask)
if mask.max() == 0:
return mask
mask = mask == np.argmax(np.bincount(mask.flat)[1:]) + 1
return mask
def get_visualization(self, alpha_blend: float, click_radius: int):
if self.image is None:
return None
# 1. 正在标注的mask
# results_mask_for_vis = self.result_mask # 加入之前标完的mask
results_mask_for_vis = np.zeros_like(self.result_mask)
results_mask_for_vis *= self.curr_label_number
if self.probs_history:
results_mask_for_vis[self.current_object_prob >
self.prob_thresh] = self.curr_label_number
if self.lccFilter:
results_mask_for_vis = (self.getLargestCC(results_mask_for_vis) *
self.curr_label_number)
vis = draw_with_blend_and_clicks(
self.image,
mask=results_mask_for_vis,
alpha=alpha_blend,
clicks_list=self.clicker.clicks_list,
radius=click_radius,
palette=self.palette, )
return vis
def inImage(self, x: int, y: int):
s = self.image.shape
if x < 0 or y < 0 or x >= s[1] or y >= s[0]:
return False
return True
@property
def result_mask(self):
result_mask = self._result_mask.copy()
return result_mask
@property
def palette(self):
if self.labelList:
colors = [ml.color for ml in self.labelList]
colors.insert(0, [0, 0, 0])
else:
colors = [[0, 0, 0]]
return colors
@property
def current_object_prob(self):
"""
获取当前推理标签
"""
if self.probs_history:
_, current_prob_additive = self.probs_history[-1]
return current_prob_additive
else:
return None
@property
def is_incomplete_mask(self):
"""
Returns
bool: 当前的物体是不是还没标完
"""
return len(self.probs_history) > 0
@property
def imgShape(self):
return self.image.shape # [1::-1]
@property
def modelSet(self):
return self.model is not None
@property
def modelName(self):
return self.model.__name__
@property
def imageSet(self):
return self.image is not None
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import sys
sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__))))
from run import main
if __name__ == "__main__":
main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import cv2
import numpy as np
from copy import deepcopy
class Clicker(object):
def __init__(self,
gt_mask=None,
init_clicks=None,
ignore_label=-1,
click_indx_offset=0):
self.click_indx_offset = click_indx_offset
if gt_mask is not None:
self.gt_mask = gt_mask == 1
self.not_ignore_mask = gt_mask != ignore_label
else:
self.gt_mask = None
self.reset_clicks()
if init_clicks is not None:
for click in init_clicks:
self.add_click(click)
def make_next_click(self, pred_mask):
assert self.gt_mask is not None
click = self._get_next_click(pred_mask)
self.add_click(click)
def get_clicks(self, clicks_limit=None):
return self.clicks_list[:clicks_limit]
def _get_next_click(self, pred_mask, padding=True):
fn_mask = np.logical_and(
np.logical_and(self.gt_mask, np.logical_not(pred_mask)),
self.not_ignore_mask, )
fp_mask = np.logical_and(
np.logical_and(np.logical_not(self.gt_mask), pred_mask),
self.not_ignore_mask, )
if padding:
fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
fn_mask_dt = cv2.distanceTransform(
fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
fp_mask_dt = cv2.distanceTransform(
fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
if padding:
fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
fn_mask_dt = fn_mask_dt * self.not_clicked_map
fp_mask_dt = fp_mask_dt * self.not_clicked_map
fn_max_dist = np.max(fn_mask_dt)
fp_max_dist = np.max(fp_mask_dt)
is_positive = fn_max_dist > fp_max_dist
if is_positive:
coords_y, coords_x = np.where(
fn_mask_dt == fn_max_dist) # coords is [y, x]
else:
coords_y, coords_x = np.where(
fp_mask_dt == fp_max_dist) # coords is [y, x]
return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))
def add_click(self, click):
coords = click.coords
click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks
if click.is_positive:
self.num_pos_clicks += 1
else:
self.num_neg_clicks += 1
self.clicks_list.append(click)
if self.gt_mask is not None:
self.not_clicked_map[coords[0], coords[1]] = False
def _remove_last_click(self):
click = self.clicks_list.pop()
coords = click.coords
if click.is_positive:
self.num_pos_clicks -= 1
else:
self.num_neg_clicks -= 1
if self.gt_mask is not None:
self.not_clicked_map[coords[0], coords[1]] = True
def reset_clicks(self):
if self.gt_mask is not None:
self.not_clicked_map = np.ones_like(self.gt_mask, dtype=np.bool)
self.num_pos_clicks = 0
self.num_neg_clicks = 0
self.clicks_list = []
def get_state(self):
return deepcopy(self.clicks_list)
def set_state(self, state):
self.reset_clicks()
for click in state:
self.add_click(click)
def __len__(self):
return len(self.clicks_list)
class Click:
def __init__(self, is_positive, coords, indx=None):
self.is_positive = is_positive
self.coords = coords
self.indx = indx
@property
def coords_and_indx(self):
return (*self.coords, self.indx)
def copy(self, **kwargs):
self_copy = deepcopy(self)
for k, v in kwargs.items():
setattr(self_copy, k, v)
return self_copy
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import paddle
from .base import BasePredictor
from inference.transforms import ZoomIn
def get_predictor(net,
brs_mode,
with_flip=False,
zoom_in_params=dict(),
predictor_params=None):
predictor_params_ = {"optimize_after_n_clicks": 1}
if zoom_in_params is not None:
zoom_in = ZoomIn(**zoom_in_params)
else:
zoom_in = None
if brs_mode == "NoBRS":
if predictor_params is not None:
predictor_params_.update(predictor_params)
predictor = BasePredictor(
net, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_)
else:
raise NotImplementedError("Just support NoBRS mode")
return predictor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import paddle
import paddle.nn.functional as F
import numpy as np
from inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide
from .ops import DistMaps, ScaleLayer, BatchImageNormalize
class BasePredictor(object):
def __init__(self,
model,
net_clicks_limit=None,
with_flip=False,
zoom_in=None,
max_size=None,
with_mask=True,
**kwargs):
self.with_flip = with_flip
self.net_clicks_limit = net_clicks_limit
self.original_image = None
self.zoom_in = zoom_in
self.prev_prediction = None
self.model_indx = 0
self.click_models = None
self.net_state_dict = None
self.with_prev_mask = with_mask
self.net = model
if not paddle.in_dynamic_mode():
paddle.disable_static()
self.normalization = BatchImageNormalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
self.transforms = [zoom_in] if zoom_in is not None else []
if max_size is not None:
self.transforms.append(LimitLongestSide(max_size=max_size))
self.transforms.append(SigmoidForPred())
if with_flip:
self.transforms.append(AddHorizontalFlip())
self.dist_maps = DistMaps(
norm_radius=5, spatial_scale=1.0, cpu_mode=False, use_disks=True)
def to_tensor(self, x):
if isinstance(x, np.ndarray):
if x.ndim == 2:
x = x[:, :, None]
img = paddle.to_tensor(x.transpose([2, 0, 1])).astype("float32") / 255
return img
def set_input_image(self, image):
image_nd = self.to_tensor(image)
for transform in self.transforms:
transform.reset()
self.original_image = image_nd
if len(self.original_image.shape) == 3:
self.original_image = self.original_image.unsqueeze(0)
self.prev_prediction = paddle.zeros_like(self.original_image[:, :
1, :, :])
if not self.with_prev_mask:
self.prev_edge = paddle.zeros_like(self.original_image[:, :1, :, :])
def get_prediction(self, clicker, prev_mask=None):
clicks_list = clicker.get_clicks()
input_image = self.original_image
if prev_mask is None:
if not self.with_prev_mask:
prev_mask = self.prev_edge
else:
prev_mask = self.prev_prediction
input_image = paddle.concat([input_image, prev_mask], axis=1)
image_nd, clicks_lists, is_image_changed = self.apply_transforms(
input_image, [clicks_list])
pred_logits, pred_edges = self._get_prediction(image_nd, clicks_lists,
is_image_changed)
pred_logits = paddle.to_tensor(pred_logits)
prediction = F.interpolate(
pred_logits,
mode="bilinear",
align_corners=True,
size=image_nd.shape[2:])
if pred_edges is not None:
pred_edge = paddle.to_tensor(pred_edges)
edge_prediction = F.interpolate(
pred_edge,
mode="bilinear",
align_corners=True,
size=image_nd.shape[2:])
for t in reversed(self.transforms):
if pred_edges is not None:
edge_prediction = t.inv_transform(edge_prediction)
self.prev_edge = edge_prediction
prediction = t.inv_transform(prediction)
if self.zoom_in is not None and self.zoom_in.check_possible_recalculation(
):
return self.get_prediction(clicker)
self.prev_prediction = prediction
return prediction.numpy()[0, 0]
def prepare_input(self, image):
prev_mask = None
prev_mask = image[:, 3:, :, :]
image = image[:, :3, :, :]
image = self.normalization(image)
return image, prev_mask
def get_coord_features(self, image, prev_mask, points):
coord_features = self.dist_maps(image, points)
if prev_mask is not None:
coord_features = paddle.concat((prev_mask, coord_features), axis=1)
return coord_features
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
input_names = self.net.get_input_names()
self.input_handle_1 = self.net.get_input_handle(input_names[0])
self.input_handle_2 = self.net.get_input_handle(input_names[1])
points_nd = self.get_points_nd(clicks_lists)
image, prev_mask = self.prepare_input(image_nd)
coord_features = self.get_coord_features(image, prev_mask, points_nd)
image = image.numpy().astype("float32")
coord_features = coord_features.numpy().astype("float32")
self.input_handle_1.copy_from_cpu(image)
self.input_handle_2.copy_from_cpu(coord_features)
self.net.run()
output_names = self.net.get_output_names()
output_handle = self.net.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu()
if len(output_names) == 3:
edge_handle = self.net.get_output_handle(output_names[2])
edge_data = edge_handle.copy_to_cpu()
return output_data, edge_data
else:
return output_data, None
def _get_transform_states(self):
return [x.get_state() for x in self.transforms]
def _set_transform_states(self, states):
assert len(states) == len(self.transforms)
for state, transform in zip(states, self.transforms):
transform.set_state(state)
def apply_transforms(self, image_nd, clicks_lists):
is_image_changed = False
for t in self.transforms:
image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
is_image_changed |= t.image_changed
return image_nd, clicks_lists, is_image_changed
def get_points_nd(self, clicks_lists):
total_clicks = []
num_pos_clicks = [
sum(x.is_positive for x in clicks_list)
for clicks_list in clicks_lists
]
num_neg_clicks = [
len(clicks_list) - num_pos
for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
]
num_max_points = max(num_pos_clicks + num_neg_clicks)
if self.net_clicks_limit is not None:
num_max_points = min(self.net_clicks_limit, num_max_points)
num_max_points = max(1, num_max_points)
for clicks_list in clicks_lists:
clicks_list = clicks_list[:self.net_clicks_limit]
pos_clicks = [
click.coords_and_indx for click in clicks_list
if click.is_positive
]
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
(-1, -1, -1)
]
neg_clicks = [
click.coords_and_indx for click in clicks_list
if not click.is_positive
]
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
(-1, -1, -1)
]
total_clicks.append(pos_clicks + neg_clicks)
return paddle.to_tensor(total_clicks)
def get_states(self):
return {
"transform_states": self._get_transform_states(),
"prev_prediction": self.prev_prediction,
}
def set_states(self, states):
self._set_transform_states(states["transform_states"])
self.prev_prediction = states["prev_prediction"]
def split_points_by_order(tpoints, groups):
points = tpoints.numpy()
num_groups = len(groups)
bs = points.shape[0]
num_points = points.shape[1] // 2
groups = [x if x > 0 else num_points for x in groups]
group_points = [
np.full(
(bs, 2 * x, 3), -1, dtype=np.float32) for x in groups
]
last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int)
for group_indx, group_size in enumerate(groups):
last_point_indx_group[:, group_indx, 1] = group_size
for bindx in range(bs):
for pindx in range(2 * num_points):
point = points[bindx, pindx, :]
group_id = int(point[2])
if group_id < 0:
continue
is_negative = int(pindx >= num_points)
if group_id >= num_groups or (
group_id == 0 and
is_negative): # disable negative first click
group_id = num_groups - 1
new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
last_point_indx_group[bindx, group_id, is_negative] += 1
group_points[group_id][bindx, new_point_indx, :] = point
group_points = [
paddle.to_tensor(
x, dtype=tpoints.dtype) for x in group_points
]
return group_points
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import paddle
import paddle.nn as nn
import numpy as np
class DistMaps(nn.Layer):
def __init__(self,
norm_radius,
spatial_scale=1.0,
cpu_mode=True,
use_disks=False):
super(DistMaps, self).__init__()
self.spatial_scale = spatial_scale
self.norm_radius = norm_radius
self.cpu_mode = cpu_mode
self.use_disks = use_disks
if self.cpu_mode:
from util.cython import get_dist_maps
self._get_dist_maps = get_dist_maps
def get_coord_features(self, points, batchsize, rows, cols):
if self.cpu_mode:
coords = []
for i in range(batchsize):
norm_delimeter = (1.0 if self.use_disks else
self.spatial_scale * self.norm_radius)
coords.append(
self._get_dist_maps(points[i].numpy().astype("float32"),
rows, cols, norm_delimeter))
coords = paddle.to_tensor(np.stack(
coords, axis=0)).astype("float32")
else:
num_points = points.shape[1] // 2
points = points.reshape([-1, points.shape[2]])
points, points_order = paddle.split(points, [2, 1], axis=1)
invalid_points = paddle.max(points, axis=1, keepdim=False) < 0
row_array = paddle.arange(
start=0, end=rows, step=1, dtype="float32")
col_array = paddle.arange(
start=0, end=cols, step=1, dtype="float32")
coord_rows, coord_cols = paddle.meshgrid(row_array, col_array)
coords = paddle.unsqueeze(
paddle.stack(
[coord_rows, coord_cols], axis=0),
axis=0).tile([points.shape[0], 1, 1, 1])
add_xy = (points * self.spatial_scale).reshape(
[points.shape[0], points.shape[1], 1, 1])
coords = coords - add_xy
if not self.use_disks:
coords = coords / (self.norm_radius * self.spatial_scale)
coords = coords * coords
coords[:, 0] += coords[:, 1]
coords = coords[:, :1]
invalid_points = invalid_points.numpy()
coords[invalid_points, :, :, :] = 1e6
coords = coords.reshape([-1, num_points, 1, rows, cols])
coords = paddle.min(coords, axis=1)
coords = coords.reshape([-1, 2, rows, cols])
if self.use_disks:
coords = (
coords <=
(self.norm_radius * self.spatial_scale)**2).astype("float32")
else:
coords = paddle.tanh(paddle.sqrt(coords) * 2)
return coords
def forward(self, x, coords):
return self.get_coord_features(coords, x.shape[0], x.shape[2],
x.shape[3])
class ScaleLayer(nn.Layer):
def __init__(self, init_value=1.0, lr_mult=1):
super().__init__()
self.lr_mult = lr_mult
self.scale = self.create_parameter(
shape=[1],
dtype="float32",
default_initializer=nn.initializer.Constant(init_value / lr_mult), )
def forward(self, x):
scale = paddle.abs(self.scale * self.lr_mult)
return x * scale
class BatchImageNormalize:
def __init__(self, mean, std):
self.mean = paddle.to_tensor(
np.array(mean)[np.newaxis, :, np.newaxis, np.newaxis]).astype(
"float32")
self.std = paddle.to_tensor(
np.array(std)[np.newaxis, :, np.newaxis, np.newaxis]).astype(
"float32")
def __call__(self, tensor):
tensor = (tensor - self.mean) / self.std
return tensor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import SigmoidForPred
from .flip import AddHorizontalFlip
from .zoom_in import ZoomIn
from .limit_longest_side import LimitLongestSide
from .crops import Crops
import paddle.nn.functional as F
class BaseTransform(object):
def __init__(self):
self.image_changed = False
def transform(self, image_nd, clicks_lists):
raise NotImplementedError
def inv_transform(self, prob_map):
raise NotImplementedError
def reset(self):
raise NotImplementedError
def get_state(self):
raise NotImplementedError
def set_state(self, state):
raise NotImplementedError
class SigmoidForPred(BaseTransform):
def transform(self, image_nd, clicks_lists):
return image_nd, clicks_lists
def inv_transform(self, prob_map):
return F.sigmoid(prob_map)
def reset(self):
pass
def get_state(self):
return None
def set_state(self, state):
pass
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import math
import paddle
import numpy as np
from inference.clicker import Click
from .base import BaseTransform
class Crops(BaseTransform):
def __init__(self, crop_size=(320, 480), min_overlap=0.2):
super().__init__()
self.crop_height, self.crop_width = crop_size
self.min_overlap = min_overlap
self.x_offsets = None
self.y_offsets = None
self._counts = None
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
image_height, image_width = image_nd.shape[2:4]
self._counts = None
if image_height < self.crop_height or image_width < self.crop_width:
return image_nd, clicks_lists
self.x_offsets = get_offsets(image_width, self.crop_width,
self.min_overlap)
self.y_offsets = get_offsets(image_height, self.crop_height,
self.min_overlap)
self._counts = np.zeros((image_height, image_width))
image_crops = []
for dy in self.y_offsets:
for dx in self.x_offsets:
self._counts[dy:dy + self.crop_height, dx:dx +
self.crop_width] += 1
image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx +
self.crop_width]
image_crops.append(image_crop)
image_crops = paddle.concat(image_crops, axis=0)
self._counts = paddle.to_tensor(self._counts, dtype="float32")
clicks_list = clicks_lists[0]
clicks_lists = []
for dy in self.y_offsets:
for dx in self.x_offsets:
crop_clicks = [
x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx))
for x in clicks_list
]
clicks_lists.append(crop_clicks)
return image_crops, clicks_lists
def inv_transform(self, prob_map):
if self._counts is None:
return prob_map
new_prob_map = paddle.zeros(
(1, 1, *self._counts.shape), dtype=prob_map.dtype)
crop_indx = 0
for dy in self.y_offsets:
for dx in self.x_offsets:
new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx +
self.crop_width] += prob_map[crop_indx, 0]
crop_indx += 1
new_prob_map = paddle.divide(new_prob_map, self._counts)
return new_prob_map
def get_state(self):
return self.x_offsets, self.y_offsets, self._counts
def set_state(self, state):
self.x_offsets, self.y_offsets, self._counts = state
def reset(self):
self.x_offsets = None
self.y_offsets = None
self._counts = None
def get_offsets(length, crop_size, min_overlap_ratio=0.2):
if length == crop_size:
return [0]
N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
N = math.ceil(N)
overlap_ratio = (N - length / crop_size) / (N - 1)
overlap_width = int(crop_size * overlap_ratio)
offsets = [0]
for i in range(1, N):
new_offset = offsets[-1] + crop_size - overlap_width
if new_offset + crop_size > length:
new_offset = length - crop_size
offsets.append(new_offset)
return offsets
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import paddle
from inference.clicker import Click
from .base import BaseTransform
class AddHorizontalFlip(BaseTransform):
def transform(self, image_nd, clicks_lists):
assert len(image_nd.shape) == 4
image_nd = paddle.concat(
[image_nd, paddle.flip(
image_nd, axis=[3])], axis=0)
image_width = image_nd.shape[3]
clicks_lists_flipped = []
for clicks_list in clicks_lists:
clicks_list_flipped = [
click.copy(coords=(click.coords[0],
image_width - click.coords[1] - 1))
for click in clicks_list
]
clicks_lists_flipped.append(clicks_list_flipped)
clicks_lists = clicks_lists + clicks_lists_flipped
return image_nd, clicks_lists
def inv_transform(self, prob_map):
assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0
num_maps = prob_map.shape[0] // 2
prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:]
return 0.5 * (prob_map + paddle.flip(prob_map_flipped, axis=[3]))
def get_state(self):
return None
def set_state(self, state):
pass
def reset(self):
pass
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
from .zoom_in import ZoomIn, get_roi_image_nd
class LimitLongestSide(ZoomIn):
def __init__(self, max_size=800):
super().__init__(target_size=max_size, skip_clicks=0)
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
image_max_size = max(image_nd.shape[2:4])
self.image_changed = False
if image_max_size <= self.target_size:
return image_nd, clicks_lists
self._input_image = image_nd
self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1)
self._roi_image = get_roi_image_nd(image_nd, self._object_roi,
self.target_size)
self.image_changed = True
tclicks_lists = [self._transform_clicks(clicks_lists[0])]
return self._roi_image, tclicks_lists
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import paddle
import numpy as np
from inference.clicker import Click
from util.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
from .base import BaseTransform
class ZoomIn(BaseTransform):
def __init__(
self,
target_size=700,
skip_clicks=1,
expansion_ratio=1.4,
min_crop_size=480,
recompute_thresh_iou=0.5,
prob_thresh=0.50, ):
super().__init__()
self.target_size = target_size
self.min_crop_size = min_crop_size
self.skip_clicks = skip_clicks
self.expansion_ratio = expansion_ratio
self.recompute_thresh_iou = recompute_thresh_iou
self.prob_thresh = prob_thresh
self._input_image_shape = None
self._prev_probs = None
self._object_roi = None
self._roi_image = None
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
self.image_changed = False
clicks_list = clicks_lists[0]
if len(clicks_list) <= self.skip_clicks:
return image_nd, clicks_lists
self._input_image_shape = image_nd.shape
current_object_roi = None
if self._prev_probs is not None:
current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
if current_pred_mask.sum() > 0:
current_object_roi = get_object_roi(
current_pred_mask,
clicks_list,
self.expansion_ratio,
self.min_crop_size, )
if current_object_roi is None:
if self.skip_clicks >= 0:
return image_nd, clicks_lists
else:
current_object_roi = 0, image_nd.shape[
2] - 1, 0, image_nd.shape[3] - 1
update_object_roi = False
if self._object_roi is None:
update_object_roi = True
elif not check_object_roi(self._object_roi, clicks_list):
update_object_roi = True
elif (get_bbox_iou(current_object_roi, self._object_roi) <
self.recompute_thresh_iou):
update_object_roi = True
if update_object_roi:
self._object_roi = current_object_roi
self.image_changed = True
self._roi_image = get_roi_image_nd(image_nd, self._object_roi,
self.target_size)
tclicks_lists = [self._transform_clicks(clicks_list)]
return self._roi_image, tclicks_lists
def inv_transform(self, prob_map):
if self._object_roi is None:
self._prev_probs = prob_map.numpy()
return prob_map
assert prob_map.shape[0] == 1
rmin, rmax, cmin, cmax = self._object_roi
prob_map = paddle.nn.functional.interpolate(
prob_map,
size=(rmax - rmin + 1, cmax - cmin + 1),
mode="bilinear",
align_corners=True, )
if self._prev_probs is not None:
new_prob_map = paddle.zeros(
shape=self._prev_probs.shape, dtype=prob_map.dtype)
new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
else:
new_prob_map = prob_map
self._prev_probs = new_prob_map.numpy()
return new_prob_map
def check_possible_recalculation(self):
if (self._prev_probs is None or self._object_roi is not None or
self.skip_clicks > 0):
return False
pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
if pred_mask.sum() > 0:
possible_object_roi = get_object_roi(
pred_mask, [], self.expansion_ratio, self.min_crop_size)
image_roi = (
0,
self._input_image_shape[2] - 1,
0,
self._input_image_shape[3] - 1, )
if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
return True
return False
def get_state(self):
roi_image = self._roi_image if self._roi_image is not None else None
return (
self._input_image_shape,
self._object_roi,
self._prev_probs,
roi_image,
self.image_changed, )
def set_state(self, state):
(
self._input_image_shape,
self._object_roi,
self._prev_probs,
self._roi_image,
self.image_changed, ) = state
def reset(self):
self._input_image_shape = None
self._object_roi = None
self._prev_probs = None
self._roi_image = None
self.image_changed = False
def _transform_clicks(self, clicks_list):
if self._object_roi is None:
return clicks_list
rmin, rmax, cmin, cmax = self._object_roi
crop_height, crop_width = self._roi_image.shape[2:]
transformed_clicks = []
for click in clicks_list:
new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1)
new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1)
transformed_clicks.append(click.copy(coords=(new_r, new_c)))
return transformed_clicks
def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
pred_mask = pred_mask.copy()
for click in clicks_list:
if click.is_positive:
pred_mask[int(click.coords[0]), int(click.coords[1])] = 1
bbox = get_bbox_from_mask(pred_mask)
bbox = expand_bbox(bbox, expansion_ratio, min_crop_size)
h, w = pred_mask.shape[0], pred_mask.shape[1]
bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1)
return bbox
def get_roi_image_nd(image_nd, object_roi, target_size):
rmin, rmax, cmin, cmax = object_roi
height = rmax - rmin + 1
width = cmax - cmin + 1
if isinstance(target_size, tuple):
new_height, new_width = target_size
else:
scale = target_size / max(height, width)
new_height = int(round(height * scale))
new_width = int(round(width * scale))
with paddle.no_grad():
roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
roi_image_nd = paddle.nn.functional.interpolate(
roi_image_nd,
size=(new_height, new_width),
mode="bilinear",
align_corners=True, )
return roi_image_nd
def check_object_roi(object_roi, clicks_list):
for click in clicks_list:
if click.is_positive:
if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[
1]:
return False
if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[
3]:
return False
return True
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
from abc import abstractmethod
import paddle.inference as paddle_infer
here = osp.dirname(osp.abspath(__file__))
class EISegModel:
@abstractmethod
def __init__(self, model_path, param_path, use_gpu=False):
model_path, param_path = self.check_param(model_path, param_path)
try:
config = paddle_infer.Config(model_path, param_path)
except:
ValueError(" 模型和参数不匹配,请检查模型和参数是否加载错误")
if not use_gpu:
config.enable_mkldnn()
# TODO: fluid要废弃了,研究判断方式
# if paddle.fluid.core.supports_bfloat16():
# config.enable_mkldnn_bfloat16()
config.switch_ir_optim(True)
config.set_cpu_math_library_num_threads(10)
else:
config.enable_use_gpu(500, 0)
config.delete_pass("conv_elementwise_add_act_fuse_pass")
config.delete_pass("conv_elementwise_add2_act_fuse_pass")
config.delete_pass("conv_elementwise_add_fuse_pass")
config.switch_ir_optim()
config.enable_memory_optim()
# use_tensoret = False # TODO: 目前Linux和windows下使用TensorRT报错
# if use_tensoret:
# config.enable_tensorrt_engine(
# workspace_size=1 << 30,
# precision_mode=paddle_infer.PrecisionType.Float32,
# max_batch_size=1,
# min_subgraph_size=5,
# use_static=False,
# use_calib_mode=False,
# )
self.model = paddle_infer.create_predictor(config)
def check_param(self, model_path, param_path):
if model_path is None or not osp.exists(model_path):
raise Exception(f"模型路径{model_path}不存在。请指定正确的模型路径")
if param_path is None or not osp.exists(param_path):
raise Exception(f"权重路径{param_path}不存在。请指定正确的权重路径")
return model_path, param_path
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .med import has_sitk, dcm_reader, windowlize
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
from eiseg import logger
def has_sitk():
try:
import SimpleITK
return True
except ImportError:
return False
if has_sitk():
import SimpleITK as sitk
def dcm_reader(path):
logger.debug(f"opening medical image {path}")
reader = sitk.ImageSeriesReader()
reader.SetFileNames([path])
image = reader.Execute()
img = sitk.GetArrayFromImage(image)
logger.debug(f"scan shape is {img.shape}")
if len(img.shape) == 4:
img = img[0]
# WHC
img = np.transpose(img, [1, 2, 0])
return img.astype(np.int32)
def windowlize(scan, ww, wc):
wl = wc - ww / 2
wh = wc + ww / 2
res = scan.copy()
res = res.astype(np.float32)
res = np.clip(res, wl, wh)
res = (res - wl) / ww * 255
res = res.astype(np.uint8)
# print("++", res.shape)
# for idx in range(res.shape[-1]):
# TODO: 支持3d或者改调用
res = cv2.cvtColor(res, cv2.COLOR_GRAY2BGR)
return res
# def open_nii(niiimg_path):
# if IPT_SITK == True:
# sitk_image = sitk.ReadImage(niiimg_path)
# return _nii2arr(sitk_image)
# else:
# raise ImportError("can't import SimpleITK!")
#
# def _nii2arr(sitk_image):
# if IPT_SITK == True:
# img = sitk.GetArrayFromImage(sitk_image).transpose((1, 2, 0))
# return img
# else:
# raise ImportError("can't import SimpleITK!")
#
#
# def slice_img(img, index):
# if index == 0:
# return sample_norm(
# cv2.merge(
# [
# np.uint16(img[:, :, index]),
# np.uint16(img[:, :, index]),
# np.uint16(img[:, :, index + 1]),
# ]
# )
# )
# elif index == img.shape[2] - 1:
# return sample_norm(
# cv2.merge(
# [
# np.uint16(img[:, :, index - 1]),
# np.uint16(img[:, :, index]),
# np.uint16(img[:, :, index]),
# ]
# )
# )
# else:
# return sample_norm(
# cv2.merge(
# [
# np.uint16(img[:, :, index - 1]),
# np.uint16(img[:, :, index]),
# np.uint16(img[:, :, index + 1]),
# ]
# )
# )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment