Commit 0d97cc8c authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add new model

parents
Pipeline #316 failed with stages
in 0 seconds
This diff is collapsed.
53,119,181
245,128,6
67,159,36
204,43,41
145,104,190
135,86,75
219,120,195
127,127,127
187,189,18
72,190,207
178,199,233
248,187,118
160,222,135
247,153,150
195,176,214
192,156,148
241,183,211
199,199,199
218,219,139
166,218,229
\ No newline at end of file
shortcut:
about: Q
auto_save: X
change_output_dir: Shift+Z
clear: Ctrl+Shift+Z
clear_label: ''
clear_recent: ''
close: Ctrl+W
data_worker: ''
del_active_polygon: Backspace
edit_shortcuts: E
finish_object: Space
grid_ann: ''
label_worker: ''
largest_component: ''
load_label: ''
load_param: Ctrl+M
medical_worker: ''
model_worker: ''
open_folder: Shift+A
open_image: Ctrl+A
origional_extension: ''
quick_start: ''
quit: ''
redo: Ctrl+Y
remote_worker: ''
save: ''
save_as: ''
save_coco: ''
save_json: ''
save_label: ''
save_pseudo: ''
set_worker: ''
turn_next: F
turn_prev: S
undo: Ctrl+Z
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import time
import json
import logging
import cv2
import numpy as np
from skimage.measure import label
import paddle
from eiseg import logger
from inference import clicker
from inference.predictor import get_predictor
import util
from util.vis import draw_with_blend_and_clicks
from models import EISegModel
from util import LabelList
class InteractiveController:
def __init__(
self,
predictor_params: dict=None,
prob_thresh: float=0.5, ):
"""初始化控制器.
Parameters
----------
predictor_params : dict
推理器配置
prob_thresh : float
区分前景和背景结果的阈值
"""
self.predictor_params = predictor_params
self.prob_thresh = prob_thresh
self.model = None
self.image = None
self.rawImage = None
self.predictor = None
self.clicker = clicker.Clicker()
self.states = []
self.probs_history = []
self.polygons = []
# 用于redo
self.undo_states = []
self.undo_probs_history = []
self.curr_label_number = 0
self._result_mask = None
self.labelList = LabelList()
self.lccFilter = False
self.log = logging.getLogger(__name__)
def filterLargestCC(self, do_filter: bool):
"""设置是否只保留推理结果中的最大联通块
Parameters
----------
do_filter : bool
是否只保存推理结果中的最大联通块
"""
if not isinstance(do_filter, bool):
return
self.lccFilter = do_filter
def setModel(self, param_path=None, use_gpu=None):
"""设置推理其模型.
Parameters
----------
params_path : str
模型路径
use_gpu : bool
None:检测,根据paddle版本判断
bool:按照指定是否开启GPU
Returns
-------
bool, str
是否成功设置模型, 失败原因
"""
if param_path is not None:
model_path = param_path.replace(".pdiparams", ".pdmodel")
if not osp.exists(model_path):
raise Exception(f"未在 {model_path} 找到模型文件")
if use_gpu is None:
if paddle.device.is_compiled_with_cuda(
): # TODO: 可以使用GPU却返回False
use_gpu = True
else:
use_gpu = False
logger.info(f"User paddle compiled with gpu: use_gpu {use_gpu}")
tic = time.time()
try:
self.model = EISegModel(model_path, param_path, use_gpu)
self.reset_predictor() # 即刻生效
except KeyError as e:
return False, str(e)
logger.info(f"Load model {model_path} took {time.time() - tic}")
return True, "模型设置成功"
def setImage(self, image: np.array):
"""设置当前标注的图片
Parameters
----------
image : np.array
当前标注的图片
"""
if self.model is not None:
self.image = image
self._result_mask = np.zeros(image.shape[:2], dtype=np.uint8)
self.resetLastObject()
# 标签操作
def setLabelList(self, labelList: json):
"""设置标签列表,会覆盖已有的标签列表
Parameters
----------
labelList : json
标签列表格式为
{
{
"idx" : int (like 0 or 1 or 2)
"name" : str (like "car" or "airplan")
"color" : list (like [255, 0, 0])
},
...
}
Returns
-------
type
Description of returned object.
"""
self.labelList.clear()
labels = json.loads(labelList)
for lab in labels:
self.labelList.add(lab["id"], lab["name"], lab["color"])
def addLabel(self, id: int, name: str, color: list):
self.labelList.add(id, name, color)
def delLabel(self, id: int):
self.labelList.remove(id)
def clearLabel(self):
self.labelList.clear()
def importLabel(self, path):
self.labelList.importLabel(path)
def exportLabel(self, path):
self.labelList.exportLabel(path)
# 点击操作
def addClick(self, x: int, y: int, is_positive: bool):
"""添加一个点并运行推理,保存历史用于undo
Parameters
----------
x : int
点击的横坐标
y : int
点击的纵坐标
is_positive : bool
是否点的是正点
Returns
-------
bool, str
点击是否添加成功, 失败原因
"""
# 1. 确定可以点
if not self.inImage(x, y):
return False, "点击越界"
if not self.modelSet:
return False, "未加载模型"
if not self.imageSet:
return False, "图像未设置"
if len(self.states) == 0: # 保存一个空状态
self.states.append({
"clicker": self.clicker.get_state(),
"predictor": self.predictor.get_states(),
})
# 2. 添加点击,跑推理
click = clicker.Click(is_positive=is_positive, coords=(y, x))
self.clicker.add_click(click)
pred = self.predictor.get_prediction(self.clicker)
# 3. 保存状态
self.states.append({
"clicker": self.clicker.get_state(),
"predictor": self.predictor.get_states(),
})
if self.probs_history:
self.probs_history.append((self.probs_history[-1][1], pred))
else:
self.probs_history.append((np.zeros_like(pred), pred))
# 点击之后就不能接着之前的历史redo了
self.undo_states = []
self.undo_probs_history = []
return True, "点击添加成功"
def undoClick(self):
"""
undo一步点击
"""
if len(self.states) <= 1: # == 1就只剩下一个空状态了,不用再退
return
self.undo_states.append(self.states.pop())
self.clicker.set_state(self.states[-1]["clicker"])
self.predictor.set_states(self.states[-1]["predictor"])
self.undo_probs_history.append(self.probs_history.pop())
if not self.probs_history:
self.reset_init_mask()
def redoClick(self):
"""
redo一步点击
"""
if len(self.undo_states) == 0: # 如果还没撤销过
return
if len(self.undo_probs_history) >= 1:
next_state = self.undo_states.pop()
self.states.append(next_state)
self.clicker.set_state(next_state["clicker"])
self.predictor.set_states(next_state["predictor"])
self.probs_history.append(self.undo_probs_history.pop())
def finishObject(self, building=False):
"""
结束当前物体标注,准备标下一个
"""
object_prob = self.current_object_prob
if object_prob is None:
return None, None
object_mask = object_prob > self.prob_thresh
if self.lccFilter:
object_mask = self.getLargestCC(object_mask)
polygon = util.get_polygon(
(object_mask.astype(np.uint8) * 255),
img_size=object_mask.shape,
building=building)
if polygon is not None:
self._result_mask[object_mask] = self.curr_label_number
self.resetLastObject()
self.polygons.append([self.curr_label_number, polygon])
return object_mask, polygon
# 多边形
def getPolygon(self):
return self.polygon
def setPolygon(self, polygon):
self.polygon = polygon
# mask
def getMask(self):
s = self.imgShape
img = np.zeros([s[0], s[1]])
for poly in self.polygons:
pts = np.int32([np.array(poly[1])])
cv2.fillPoly(img, pts=pts, color=poly[0])
return img
def setCurrLabelIdx(self, number):
if not isinstance(number, int):
return False
self.curr_label_number = number
def resetLastObject(self, update_image=True):
"""
重置控制器状态
Parameters
update_image(bool): 是否更新图像
"""
self.states = []
self.probs_history = []
self.undo_states = []
self.undo_probs_history = []
# self.current_object_prob = None
self.clicker.reset_clicks()
self.reset_predictor()
self.reset_init_mask()
def reset_predictor(self, predictor_params=None):
"""
重置推理器,可以换推理配置
Parameters
predictor_params(dict): 推理配置
"""
if predictor_params is not None:
self.predictor_params = predictor_params
if self.model.model:
self.predictor = get_predictor(self.model.model,
**self.predictor_params)
if self.image is not None:
self.predictor.set_input_image(self.image)
def reset_init_mask(self):
self.clicker.click_indx_offset = 0
def getLargestCC(self, mask):
mask = label(mask)
if mask.max() == 0:
return mask
mask = mask == np.argmax(np.bincount(mask.flat)[1:]) + 1
return mask
def get_visualization(self, alpha_blend: float, click_radius: int):
if self.image is None:
return None
# 1. 正在标注的mask
# results_mask_for_vis = self.result_mask # 加入之前标完的mask
results_mask_for_vis = np.zeros_like(self.result_mask)
results_mask_for_vis *= self.curr_label_number
if self.probs_history:
results_mask_for_vis[self.current_object_prob >
self.prob_thresh] = self.curr_label_number
if self.lccFilter:
results_mask_for_vis = (self.getLargestCC(results_mask_for_vis) *
self.curr_label_number)
vis = draw_with_blend_and_clicks(
self.image,
mask=results_mask_for_vis,
alpha=alpha_blend,
clicks_list=self.clicker.clicks_list,
radius=click_radius,
palette=self.palette, )
return vis
def inImage(self, x: int, y: int):
s = self.image.shape
if x < 0 or y < 0 or x >= s[1] or y >= s[0]:
return False
return True
@property
def result_mask(self):
result_mask = self._result_mask.copy()
return result_mask
@property
def palette(self):
if self.labelList:
colors = [ml.color for ml in self.labelList]
colors.insert(0, [0, 0, 0])
else:
colors = [[0, 0, 0]]
return colors
@property
def current_object_prob(self):
"""
获取当前推理标签
"""
if self.probs_history:
_, current_prob_additive = self.probs_history[-1]
return current_prob_additive
else:
return None
@property
def is_incomplete_mask(self):
"""
Returns
bool: 当前的物体是不是还没标完
"""
return len(self.probs_history) > 0
@property
def imgShape(self):
return self.image.shape # [1::-1]
@property
def modelSet(self):
return self.model is not None
@property
def modelName(self):
return self.model.__name__
@property
def imageSet(self):
return self.image is not None
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import sys
sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__))))
from run import main
if __name__ == "__main__":
main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import cv2
import numpy as np
from copy import deepcopy
class Clicker(object):
def __init__(self,
gt_mask=None,
init_clicks=None,
ignore_label=-1,
click_indx_offset=0):
self.click_indx_offset = click_indx_offset
if gt_mask is not None:
self.gt_mask = gt_mask == 1
self.not_ignore_mask = gt_mask != ignore_label
else:
self.gt_mask = None
self.reset_clicks()
if init_clicks is not None:
for click in init_clicks:
self.add_click(click)
def make_next_click(self, pred_mask):
assert self.gt_mask is not None
click = self._get_next_click(pred_mask)
self.add_click(click)
def get_clicks(self, clicks_limit=None):
return self.clicks_list[:clicks_limit]
def _get_next_click(self, pred_mask, padding=True):
fn_mask = np.logical_and(
np.logical_and(self.gt_mask, np.logical_not(pred_mask)),
self.not_ignore_mask, )
fp_mask = np.logical_and(
np.logical_and(np.logical_not(self.gt_mask), pred_mask),
self.not_ignore_mask, )
if padding:
fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
fn_mask_dt = cv2.distanceTransform(
fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
fp_mask_dt = cv2.distanceTransform(
fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
if padding:
fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
fn_mask_dt = fn_mask_dt * self.not_clicked_map
fp_mask_dt = fp_mask_dt * self.not_clicked_map
fn_max_dist = np.max(fn_mask_dt)
fp_max_dist = np.max(fp_mask_dt)
is_positive = fn_max_dist > fp_max_dist
if is_positive:
coords_y, coords_x = np.where(
fn_mask_dt == fn_max_dist) # coords is [y, x]
else:
coords_y, coords_x = np.where(
fp_mask_dt == fp_max_dist) # coords is [y, x]
return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))
def add_click(self, click):
coords = click.coords
click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks
if click.is_positive:
self.num_pos_clicks += 1
else:
self.num_neg_clicks += 1
self.clicks_list.append(click)
if self.gt_mask is not None:
self.not_clicked_map[coords[0], coords[1]] = False
def _remove_last_click(self):
click = self.clicks_list.pop()
coords = click.coords
if click.is_positive:
self.num_pos_clicks -= 1
else:
self.num_neg_clicks -= 1
if self.gt_mask is not None:
self.not_clicked_map[coords[0], coords[1]] = True
def reset_clicks(self):
if self.gt_mask is not None:
self.not_clicked_map = np.ones_like(self.gt_mask, dtype=np.bool)
self.num_pos_clicks = 0
self.num_neg_clicks = 0
self.clicks_list = []
def get_state(self):
return deepcopy(self.clicks_list)
def set_state(self, state):
self.reset_clicks()
for click in state:
self.add_click(click)
def __len__(self):
return len(self.clicks_list)
class Click:
def __init__(self, is_positive, coords, indx=None):
self.is_positive = is_positive
self.coords = coords
self.indx = indx
@property
def coords_and_indx(self):
return (*self.coords, self.indx)
def copy(self, **kwargs):
self_copy = deepcopy(self)
for k, v in kwargs.items():
setattr(self_copy, k, v)
return self_copy
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import paddle
from .base import BasePredictor
from inference.transforms import ZoomIn
def get_predictor(net,
brs_mode,
with_flip=False,
zoom_in_params=dict(),
predictor_params=None):
predictor_params_ = {"optimize_after_n_clicks": 1}
if zoom_in_params is not None:
zoom_in = ZoomIn(**zoom_in_params)
else:
zoom_in = None
if brs_mode == "NoBRS":
if predictor_params is not None:
predictor_params_.update(predictor_params)
predictor = BasePredictor(
net, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_)
else:
raise NotImplementedError("Just support NoBRS mode")
return predictor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import paddle
import paddle.nn.functional as F
import numpy as np
from inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide
from .ops import DistMaps, ScaleLayer, BatchImageNormalize
class BasePredictor(object):
def __init__(self,
model,
net_clicks_limit=None,
with_flip=False,
zoom_in=None,
max_size=None,
with_mask=True,
**kwargs):
self.with_flip = with_flip
self.net_clicks_limit = net_clicks_limit
self.original_image = None
self.zoom_in = zoom_in
self.prev_prediction = None
self.model_indx = 0
self.click_models = None
self.net_state_dict = None
self.with_prev_mask = with_mask
self.net = model
if not paddle.in_dynamic_mode():
paddle.disable_static()
self.normalization = BatchImageNormalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
self.transforms = [zoom_in] if zoom_in is not None else []
if max_size is not None:
self.transforms.append(LimitLongestSide(max_size=max_size))
self.transforms.append(SigmoidForPred())
if with_flip:
self.transforms.append(AddHorizontalFlip())
self.dist_maps = DistMaps(
norm_radius=5, spatial_scale=1.0, cpu_mode=False, use_disks=True)
def to_tensor(self, x):
if isinstance(x, np.ndarray):
if x.ndim == 2:
x = x[:, :, None]
img = paddle.to_tensor(x.transpose([2, 0, 1])).astype("float32") / 255
return img
def set_input_image(self, image):
image_nd = self.to_tensor(image)
for transform in self.transforms:
transform.reset()
self.original_image = image_nd
if len(self.original_image.shape) == 3:
self.original_image = self.original_image.unsqueeze(0)
self.prev_prediction = paddle.zeros_like(self.original_image[:, :
1, :, :])
if not self.with_prev_mask:
self.prev_edge = paddle.zeros_like(self.original_image[:, :1, :, :])
def get_prediction(self, clicker, prev_mask=None):
clicks_list = clicker.get_clicks()
input_image = self.original_image
if prev_mask is None:
if not self.with_prev_mask:
prev_mask = self.prev_edge
else:
prev_mask = self.prev_prediction
input_image = paddle.concat([input_image, prev_mask], axis=1)
image_nd, clicks_lists, is_image_changed = self.apply_transforms(
input_image, [clicks_list])
pred_logits, pred_edges = self._get_prediction(image_nd, clicks_lists,
is_image_changed)
pred_logits = paddle.to_tensor(pred_logits)
prediction = F.interpolate(
pred_logits,
mode="bilinear",
align_corners=True,
size=image_nd.shape[2:])
if pred_edges is not None:
pred_edge = paddle.to_tensor(pred_edges)
edge_prediction = F.interpolate(
pred_edge,
mode="bilinear",
align_corners=True,
size=image_nd.shape[2:])
for t in reversed(self.transforms):
if pred_edges is not None:
edge_prediction = t.inv_transform(edge_prediction)
self.prev_edge = edge_prediction
prediction = t.inv_transform(prediction)
if self.zoom_in is not None and self.zoom_in.check_possible_recalculation(
):
return self.get_prediction(clicker)
self.prev_prediction = prediction
return prediction.numpy()[0, 0]
def prepare_input(self, image):
prev_mask = None
prev_mask = image[:, 3:, :, :]
image = image[:, :3, :, :]
image = self.normalization(image)
return image, prev_mask
def get_coord_features(self, image, prev_mask, points):
coord_features = self.dist_maps(image, points)
if prev_mask is not None:
coord_features = paddle.concat((prev_mask, coord_features), axis=1)
return coord_features
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
input_names = self.net.get_input_names()
self.input_handle_1 = self.net.get_input_handle(input_names[0])
self.input_handle_2 = self.net.get_input_handle(input_names[1])
points_nd = self.get_points_nd(clicks_lists)
image, prev_mask = self.prepare_input(image_nd)
coord_features = self.get_coord_features(image, prev_mask, points_nd)
image = image.numpy().astype("float32")
coord_features = coord_features.numpy().astype("float32")
self.input_handle_1.copy_from_cpu(image)
self.input_handle_2.copy_from_cpu(coord_features)
self.net.run()
output_names = self.net.get_output_names()
output_handle = self.net.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu()
if len(output_names) == 3:
edge_handle = self.net.get_output_handle(output_names[2])
edge_data = edge_handle.copy_to_cpu()
return output_data, edge_data
else:
return output_data, None
def _get_transform_states(self):
return [x.get_state() for x in self.transforms]
def _set_transform_states(self, states):
assert len(states) == len(self.transforms)
for state, transform in zip(states, self.transforms):
transform.set_state(state)
def apply_transforms(self, image_nd, clicks_lists):
is_image_changed = False
for t in self.transforms:
image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
is_image_changed |= t.image_changed
return image_nd, clicks_lists, is_image_changed
def get_points_nd(self, clicks_lists):
total_clicks = []
num_pos_clicks = [
sum(x.is_positive for x in clicks_list)
for clicks_list in clicks_lists
]
num_neg_clicks = [
len(clicks_list) - num_pos
for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
]
num_max_points = max(num_pos_clicks + num_neg_clicks)
if self.net_clicks_limit is not None:
num_max_points = min(self.net_clicks_limit, num_max_points)
num_max_points = max(1, num_max_points)
for clicks_list in clicks_lists:
clicks_list = clicks_list[:self.net_clicks_limit]
pos_clicks = [
click.coords_and_indx for click in clicks_list
if click.is_positive
]
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
(-1, -1, -1)
]
neg_clicks = [
click.coords_and_indx for click in clicks_list
if not click.is_positive
]
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
(-1, -1, -1)
]
total_clicks.append(pos_clicks + neg_clicks)
return paddle.to_tensor(total_clicks)
def get_states(self):
return {
"transform_states": self._get_transform_states(),
"prev_prediction": self.prev_prediction,
}
def set_states(self, states):
self._set_transform_states(states["transform_states"])
self.prev_prediction = states["prev_prediction"]
def split_points_by_order(tpoints, groups):
points = tpoints.numpy()
num_groups = len(groups)
bs = points.shape[0]
num_points = points.shape[1] // 2
groups = [x if x > 0 else num_points for x in groups]
group_points = [
np.full(
(bs, 2 * x, 3), -1, dtype=np.float32) for x in groups
]
last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int)
for group_indx, group_size in enumerate(groups):
last_point_indx_group[:, group_indx, 1] = group_size
for bindx in range(bs):
for pindx in range(2 * num_points):
point = points[bindx, pindx, :]
group_id = int(point[2])
if group_id < 0:
continue
is_negative = int(pindx >= num_points)
if group_id >= num_groups or (
group_id == 0 and
is_negative): # disable negative first click
group_id = num_groups - 1
new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
last_point_indx_group[bindx, group_id, is_negative] += 1
group_points[group_id][bindx, new_point_indx, :] = point
group_points = [
paddle.to_tensor(
x, dtype=tpoints.dtype) for x in group_points
]
return group_points
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import paddle
import paddle.nn as nn
import numpy as np
class DistMaps(nn.Layer):
def __init__(self,
norm_radius,
spatial_scale=1.0,
cpu_mode=True,
use_disks=False):
super(DistMaps, self).__init__()
self.spatial_scale = spatial_scale
self.norm_radius = norm_radius
self.cpu_mode = cpu_mode
self.use_disks = use_disks
if self.cpu_mode:
from util.cython import get_dist_maps
self._get_dist_maps = get_dist_maps
def get_coord_features(self, points, batchsize, rows, cols):
if self.cpu_mode:
coords = []
for i in range(batchsize):
norm_delimeter = (1.0 if self.use_disks else
self.spatial_scale * self.norm_radius)
coords.append(
self._get_dist_maps(points[i].numpy().astype("float32"),
rows, cols, norm_delimeter))
coords = paddle.to_tensor(np.stack(
coords, axis=0)).astype("float32")
else:
num_points = points.shape[1] // 2
points = points.reshape([-1, points.shape[2]])
points, points_order = paddle.split(points, [2, 1], axis=1)
invalid_points = paddle.max(points, axis=1, keepdim=False) < 0
row_array = paddle.arange(
start=0, end=rows, step=1, dtype="float32")
col_array = paddle.arange(
start=0, end=cols, step=1, dtype="float32")
coord_rows, coord_cols = paddle.meshgrid(row_array, col_array)
coords = paddle.unsqueeze(
paddle.stack(
[coord_rows, coord_cols], axis=0),
axis=0).tile([points.shape[0], 1, 1, 1])
add_xy = (points * self.spatial_scale).reshape(
[points.shape[0], points.shape[1], 1, 1])
coords = coords - add_xy
if not self.use_disks:
coords = coords / (self.norm_radius * self.spatial_scale)
coords = coords * coords
coords[:, 0] += coords[:, 1]
coords = coords[:, :1]
invalid_points = invalid_points.numpy()
coords[invalid_points, :, :, :] = 1e6
coords = coords.reshape([-1, num_points, 1, rows, cols])
coords = paddle.min(coords, axis=1)
coords = coords.reshape([-1, 2, rows, cols])
if self.use_disks:
coords = (
coords <=
(self.norm_radius * self.spatial_scale)**2).astype("float32")
else:
coords = paddle.tanh(paddle.sqrt(coords) * 2)
return coords
def forward(self, x, coords):
return self.get_coord_features(coords, x.shape[0], x.shape[2],
x.shape[3])
class ScaleLayer(nn.Layer):
def __init__(self, init_value=1.0, lr_mult=1):
super().__init__()
self.lr_mult = lr_mult
self.scale = self.create_parameter(
shape=[1],
dtype="float32",
default_initializer=nn.initializer.Constant(init_value / lr_mult), )
def forward(self, x):
scale = paddle.abs(self.scale * self.lr_mult)
return x * scale
class BatchImageNormalize:
def __init__(self, mean, std):
self.mean = paddle.to_tensor(
np.array(mean)[np.newaxis, :, np.newaxis, np.newaxis]).astype(
"float32")
self.std = paddle.to_tensor(
np.array(std)[np.newaxis, :, np.newaxis, np.newaxis]).astype(
"float32")
def __call__(self, tensor):
tensor = (tensor - self.mean) / self.std
return tensor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import SigmoidForPred
from .flip import AddHorizontalFlip
from .zoom_in import ZoomIn
from .limit_longest_side import LimitLongestSide
from .crops import Crops
import paddle.nn.functional as F
class BaseTransform(object):
def __init__(self):
self.image_changed = False
def transform(self, image_nd, clicks_lists):
raise NotImplementedError
def inv_transform(self, prob_map):
raise NotImplementedError
def reset(self):
raise NotImplementedError
def get_state(self):
raise NotImplementedError
def set_state(self, state):
raise NotImplementedError
class SigmoidForPred(BaseTransform):
def transform(self, image_nd, clicks_lists):
return image_nd, clicks_lists
def inv_transform(self, prob_map):
return F.sigmoid(prob_map)
def reset(self):
pass
def get_state(self):
return None
def set_state(self, state):
pass
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import math
import paddle
import numpy as np
from inference.clicker import Click
from .base import BaseTransform
class Crops(BaseTransform):
def __init__(self, crop_size=(320, 480), min_overlap=0.2):
super().__init__()
self.crop_height, self.crop_width = crop_size
self.min_overlap = min_overlap
self.x_offsets = None
self.y_offsets = None
self._counts = None
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
image_height, image_width = image_nd.shape[2:4]
self._counts = None
if image_height < self.crop_height or image_width < self.crop_width:
return image_nd, clicks_lists
self.x_offsets = get_offsets(image_width, self.crop_width,
self.min_overlap)
self.y_offsets = get_offsets(image_height, self.crop_height,
self.min_overlap)
self._counts = np.zeros((image_height, image_width))
image_crops = []
for dy in self.y_offsets:
for dx in self.x_offsets:
self._counts[dy:dy + self.crop_height, dx:dx +
self.crop_width] += 1
image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx +
self.crop_width]
image_crops.append(image_crop)
image_crops = paddle.concat(image_crops, axis=0)
self._counts = paddle.to_tensor(self._counts, dtype="float32")
clicks_list = clicks_lists[0]
clicks_lists = []
for dy in self.y_offsets:
for dx in self.x_offsets:
crop_clicks = [
x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx))
for x in clicks_list
]
clicks_lists.append(crop_clicks)
return image_crops, clicks_lists
def inv_transform(self, prob_map):
if self._counts is None:
return prob_map
new_prob_map = paddle.zeros(
(1, 1, *self._counts.shape), dtype=prob_map.dtype)
crop_indx = 0
for dy in self.y_offsets:
for dx in self.x_offsets:
new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx +
self.crop_width] += prob_map[crop_indx, 0]
crop_indx += 1
new_prob_map = paddle.divide(new_prob_map, self._counts)
return new_prob_map
def get_state(self):
return self.x_offsets, self.y_offsets, self._counts
def set_state(self, state):
self.x_offsets, self.y_offsets, self._counts = state
def reset(self):
self.x_offsets = None
self.y_offsets = None
self._counts = None
def get_offsets(length, crop_size, min_overlap_ratio=0.2):
if length == crop_size:
return [0]
N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
N = math.ceil(N)
overlap_ratio = (N - length / crop_size) / (N - 1)
overlap_width = int(crop_size * overlap_ratio)
offsets = [0]
for i in range(1, N):
new_offset = offsets[-1] + crop_size - overlap_width
if new_offset + crop_size > length:
new_offset = length - crop_size
offsets.append(new_offset)
return offsets
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import paddle
from inference.clicker import Click
from .base import BaseTransform
class AddHorizontalFlip(BaseTransform):
def transform(self, image_nd, clicks_lists):
assert len(image_nd.shape) == 4
image_nd = paddle.concat(
[image_nd, paddle.flip(
image_nd, axis=[3])], axis=0)
image_width = image_nd.shape[3]
clicks_lists_flipped = []
for clicks_list in clicks_lists:
clicks_list_flipped = [
click.copy(coords=(click.coords[0],
image_width - click.coords[1] - 1))
for click in clicks_list
]
clicks_lists_flipped.append(clicks_list_flipped)
clicks_lists = clicks_lists + clicks_lists_flipped
return image_nd, clicks_lists
def inv_transform(self, prob_map):
assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0
num_maps = prob_map.shape[0] // 2
prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:]
return 0.5 * (prob_map + paddle.flip(prob_map_flipped, axis=[3]))
def get_state(self):
return None
def set_state(self, state):
pass
def reset(self):
pass
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
from .zoom_in import ZoomIn, get_roi_image_nd
class LimitLongestSide(ZoomIn):
def __init__(self, max_size=800):
super().__init__(target_size=max_size, skip_clicks=0)
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
image_max_size = max(image_nd.shape[2:4])
self.image_changed = False
if image_max_size <= self.target_size:
return image_nd, clicks_lists
self._input_image = image_nd
self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1)
self._roi_image = get_roi_image_nd(image_nd, self._object_roi,
self.target_size)
self.image_changed = True
tclicks_lists = [self._transform_clicks(clicks_lists[0])]
return self._roi_image, tclicks_lists
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
Ths copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""
import paddle
import numpy as np
from inference.clicker import Click
from util.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
from .base import BaseTransform
class ZoomIn(BaseTransform):
def __init__(
self,
target_size=700,
skip_clicks=1,
expansion_ratio=1.4,
min_crop_size=480,
recompute_thresh_iou=0.5,
prob_thresh=0.50, ):
super().__init__()
self.target_size = target_size
self.min_crop_size = min_crop_size
self.skip_clicks = skip_clicks
self.expansion_ratio = expansion_ratio
self.recompute_thresh_iou = recompute_thresh_iou
self.prob_thresh = prob_thresh
self._input_image_shape = None
self._prev_probs = None
self._object_roi = None
self._roi_image = None
def transform(self, image_nd, clicks_lists):
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
self.image_changed = False
clicks_list = clicks_lists[0]
if len(clicks_list) <= self.skip_clicks:
return image_nd, clicks_lists
self._input_image_shape = image_nd.shape
current_object_roi = None
if self._prev_probs is not None:
current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
if current_pred_mask.sum() > 0:
current_object_roi = get_object_roi(
current_pred_mask,
clicks_list,
self.expansion_ratio,
self.min_crop_size, )
if current_object_roi is None:
if self.skip_clicks >= 0:
return image_nd, clicks_lists
else:
current_object_roi = 0, image_nd.shape[
2] - 1, 0, image_nd.shape[3] - 1
update_object_roi = False
if self._object_roi is None:
update_object_roi = True
elif not check_object_roi(self._object_roi, clicks_list):
update_object_roi = True
elif (get_bbox_iou(current_object_roi, self._object_roi) <
self.recompute_thresh_iou):
update_object_roi = True
if update_object_roi:
self._object_roi = current_object_roi
self.image_changed = True
self._roi_image = get_roi_image_nd(image_nd, self._object_roi,
self.target_size)
tclicks_lists = [self._transform_clicks(clicks_list)]
return self._roi_image, tclicks_lists
def inv_transform(self, prob_map):
if self._object_roi is None:
self._prev_probs = prob_map.numpy()
return prob_map
assert prob_map.shape[0] == 1
rmin, rmax, cmin, cmax = self._object_roi
prob_map = paddle.nn.functional.interpolate(
prob_map,
size=(rmax - rmin + 1, cmax - cmin + 1),
mode="bilinear",
align_corners=True, )
if self._prev_probs is not None:
new_prob_map = paddle.zeros(
shape=self._prev_probs.shape, dtype=prob_map.dtype)
new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
else:
new_prob_map = prob_map
self._prev_probs = new_prob_map.numpy()
return new_prob_map
def check_possible_recalculation(self):
if (self._prev_probs is None or self._object_roi is not None or
self.skip_clicks > 0):
return False
pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
if pred_mask.sum() > 0:
possible_object_roi = get_object_roi(
pred_mask, [], self.expansion_ratio, self.min_crop_size)
image_roi = (
0,
self._input_image_shape[2] - 1,
0,
self._input_image_shape[3] - 1, )
if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
return True
return False
def get_state(self):
roi_image = self._roi_image if self._roi_image is not None else None
return (
self._input_image_shape,
self._object_roi,
self._prev_probs,
roi_image,
self.image_changed, )
def set_state(self, state):
(
self._input_image_shape,
self._object_roi,
self._prev_probs,
self._roi_image,
self.image_changed, ) = state
def reset(self):
self._input_image_shape = None
self._object_roi = None
self._prev_probs = None
self._roi_image = None
self.image_changed = False
def _transform_clicks(self, clicks_list):
if self._object_roi is None:
return clicks_list
rmin, rmax, cmin, cmax = self._object_roi
crop_height, crop_width = self._roi_image.shape[2:]
transformed_clicks = []
for click in clicks_list:
new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1)
new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1)
transformed_clicks.append(click.copy(coords=(new_r, new_c)))
return transformed_clicks
def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
pred_mask = pred_mask.copy()
for click in clicks_list:
if click.is_positive:
pred_mask[int(click.coords[0]), int(click.coords[1])] = 1
bbox = get_bbox_from_mask(pred_mask)
bbox = expand_bbox(bbox, expansion_ratio, min_crop_size)
h, w = pred_mask.shape[0], pred_mask.shape[1]
bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1)
return bbox
def get_roi_image_nd(image_nd, object_roi, target_size):
rmin, rmax, cmin, cmax = object_roi
height = rmax - rmin + 1
width = cmax - cmin + 1
if isinstance(target_size, tuple):
new_height, new_width = target_size
else:
scale = target_size / max(height, width)
new_height = int(round(height * scale))
new_width = int(round(width * scale))
with paddle.no_grad():
roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
roi_image_nd = paddle.nn.functional.interpolate(
roi_image_nd,
size=(new_height, new_width),
mode="bilinear",
align_corners=True, )
return roi_image_nd
def check_object_roi(object_roi, clicks_list):
for click in clicks_list:
if click.is_positive:
if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[
1]:
return False
if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[
3]:
return False
return True
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
from abc import abstractmethod
import paddle.inference as paddle_infer
here = osp.dirname(osp.abspath(__file__))
class EISegModel:
@abstractmethod
def __init__(self, model_path, param_path, use_gpu=False):
model_path, param_path = self.check_param(model_path, param_path)
try:
config = paddle_infer.Config(model_path, param_path)
except:
ValueError(" 模型和参数不匹配,请检查模型和参数是否加载错误")
if not use_gpu:
config.enable_mkldnn()
# TODO: fluid要废弃了,研究判断方式
# if paddle.fluid.core.supports_bfloat16():
# config.enable_mkldnn_bfloat16()
config.switch_ir_optim(True)
config.set_cpu_math_library_num_threads(10)
else:
config.enable_use_gpu(500, 0)
config.delete_pass("conv_elementwise_add_act_fuse_pass")
config.delete_pass("conv_elementwise_add2_act_fuse_pass")
config.delete_pass("conv_elementwise_add_fuse_pass")
config.switch_ir_optim()
config.enable_memory_optim()
# use_tensoret = False # TODO: 目前Linux和windows下使用TensorRT报错
# if use_tensoret:
# config.enable_tensorrt_engine(
# workspace_size=1 << 30,
# precision_mode=paddle_infer.PrecisionType.Float32,
# max_batch_size=1,
# min_subgraph_size=5,
# use_static=False,
# use_calib_mode=False,
# )
self.model = paddle_infer.create_predictor(config)
def check_param(self, model_path, param_path):
if model_path is None or not osp.exists(model_path):
raise Exception(f"模型路径{model_path}不存在。请指定正确的模型路径")
if param_path is None or not osp.exists(param_path):
raise Exception(f"权重路径{param_path}不存在。请指定正确的权重路径")
return model_path, param_path
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .med import has_sitk, dcm_reader, windowlize
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
from eiseg import logger
def has_sitk():
try:
import SimpleITK
return True
except ImportError:
return False
if has_sitk():
import SimpleITK as sitk
def dcm_reader(path):
logger.debug(f"opening medical image {path}")
reader = sitk.ImageSeriesReader()
reader.SetFileNames([path])
image = reader.Execute()
img = sitk.GetArrayFromImage(image)
logger.debug(f"scan shape is {img.shape}")
if len(img.shape) == 4:
img = img[0]
# WHC
img = np.transpose(img, [1, 2, 0])
return img.astype(np.int32)
def windowlize(scan, ww, wc):
wl = wc - ww / 2
wh = wc + ww / 2
res = scan.copy()
res = res.astype(np.float32)
res = np.clip(res, wl, wh)
res = (res - wl) / ww * 255
res = res.astype(np.uint8)
# print("++", res.shape)
# for idx in range(res.shape[-1]):
# TODO: 支持3d或者改调用
res = cv2.cvtColor(res, cv2.COLOR_GRAY2BGR)
return res
# def open_nii(niiimg_path):
# if IPT_SITK == True:
# sitk_image = sitk.ReadImage(niiimg_path)
# return _nii2arr(sitk_image)
# else:
# raise ImportError("can't import SimpleITK!")
#
# def _nii2arr(sitk_image):
# if IPT_SITK == True:
# img = sitk.GetArrayFromImage(sitk_image).transpose((1, 2, 0))
# return img
# else:
# raise ImportError("can't import SimpleITK!")
#
#
# def slice_img(img, index):
# if index == 0:
# return sample_norm(
# cv2.merge(
# [
# np.uint16(img[:, :, index]),
# np.uint16(img[:, :, index]),
# np.uint16(img[:, :, index + 1]),
# ]
# )
# )
# elif index == img.shape[2] - 1:
# return sample_norm(
# cv2.merge(
# [
# np.uint16(img[:, :, index - 1]),
# np.uint16(img[:, :, index]),
# np.uint16(img[:, :, index]),
# ]
# )
# )
# else:
# return sample_norm(
# cv2.merge(
# [
# np.uint16(img[:, :, index - 1]),
# np.uint16(img[:, :, index]),
# np.uint16(img[:, :, index + 1]),
# ]
# )
# )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment