Commit 0d97cc8c authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add new model

parents
Pipeline #316 failed with stages
in 0 seconds
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .rs_grid import RSGrids
from .grid import Grids, checkOpenGrid
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
from PIL import Image
def checkOpenGrid(img, thumbnail_min):
    """Return True when the image is large enough to need grid slicing.

    Args:
        img (np.ndarray): image in HW(C) layout; only height/width are read.
        thumbnail_min (int): size threshold; images whose longest side is at
            most this value are shown whole instead of being sliced.

    Returns:
        bool: True if ``max(H, W)`` exceeds ``thumbnail_min``.
    """
    H, W = img.shape[:2]
    # Direct comparison replaces the original if/else that returned booleans.
    return max(H, W) > thumbnail_min
class Grids:
    """Splits a large image into overlapping tiles ("grids") for annotation.

    The source image is kept in HWC layout; tiles of ``gridSize`` overlap
    their neighbours by ``overlap`` pixels so the per-tile masks can later
    be stitched back together by ``splicingList``.
    """

    def __init__(self, img, gridSize=(512, 512), overlap=(24, 24)):
        self.clear()
        self.detimg = img
        self.gridSize = np.array(gridSize)  # (tile height, tile width)
        self.overlap = np.array(overlap)  # (vertical, horizontal) overlap

    def clear(self):
        """Reset all per-image state."""
        # Image is in HWC format
        self.detimg = None  # source image for the grids
        self.grid_init = False  # whether the grids have been initialized
        # self.imagesGrid = []  # image tiles
        self.mask_grids = []  # mask (label) tiles
        self.json_labels = []  # saved labels
        self.grid_count = None  # (row count, col count)
        self.curr_idx = None  # (current row, current col)

    def createGrids(self):
        """Compute the tile layout and allocate one zero mask per tile.

        Returns:
            list: [row count, col count].
        """
        # Number of tiles in the vertical / horizontal direction
        imgSize = np.array(self.detimg.shape[:2])
        grid_count = np.ceil((imgSize + self.overlap) / self.gridSize)
        self.grid_count = grid_count = grid_count.astype("uint16")
        # ul = self.overlap - self.gridSize
        # for row in range(grid_count[0]):
        #     ul[0] = ul[0] + self.gridSize[0] - self.overlap[0]
        #     for col in range(grid_count[1]):
        #         ul[1] = ul[1] + self.gridSize[1] - self.overlap[1]
        #         lr = ul + self.gridSize
        #         # print("ul, lr", ul, lr)
        #         # pad
        #         det_tmp = self.detimg[ul[0]: lr[0], ul[1]: lr[1]]
        #         tmp = np.zeros((self.gridSize[0], self.gridSize[1], self.detimg.shape[-1]))
        #         tmp[:det_tmp.shape[0], :det_tmp.shape[1], :] = det_tmp
        #         self.imagesGrid.append(tmp)
        # self.mask_grids = [[np.zeros(self.gridSize)] * grid_count[1]] * grid_count[0]  # must not use shallow copies
        self.mask_grids = [
            [np.zeros(self.gridSize) for _ in range(grid_count[1])]
            for _ in range(grid_count[0])
        ]
        # print(len(self.mask_grids), len(self.mask_grids[0]))
        self.grid_init = True
        return list(grid_count)

    def getGrid(self, row, col):
        """Return the (image, mask) pair for the tile at (row, col)."""
        gridIdx = np.array([row, col])
        # Tiles step by (gridSize - overlap) so neighbours share `overlap` pixels.
        ul = gridIdx * (self.gridSize - self.overlap)
        lr = ul + self.gridSize
        img = self.detimg[ul[0]:lr[0], ul[1]:lr[1]]
        mask = self.mask_grids[row][col]
        self.curr_idx = (row, col)
        return img, mask

    def splicingList(self, save_path):
        """
        Stitch the sliding-window outputs back together; ``raw_size`` is used
        to crop the result back to the original image extent.
        """
        imgs = self.mask_grids
        raw_size = self.detimg.shape[:2]
        # h, w = None, None
        # for i in range(len(imgs)):
        #     for j in range(len(imgs[i])):
        #         im = imgs[i][j]
        #         if im is not None:
        #             h, w = im.shape[:2]
        #             break
        # if h is None and w is None:
        #     return False
        h, w = self.gridSize
        # NOTE(review): row/col here ignore the overlap that createGrids()
        # uses when counting tiles, so the two counts can differ — confirm
        # both paths always agree for the image sizes used.
        row = math.ceil(raw_size[0] / h)
        col = math.ceil(raw_size[1] / w)
        result_1 = np.zeros((h * row, w * col), dtype=np.uint8)
        result_2 = result_1.copy()
        # k = 0
        for i in range(row):
            for j in range(col):
                ih, iw = imgs[i][j].shape[:2]
                im = np.zeros(self.gridSize)
                im[:ih, :iw] = imgs[i][j]
                start_h = (i * h) if i == 0 else (i * (h - self.overlap[0]))
                end_h = start_h + h
                start_w = (j * w) if j == 0 else (j * (w - self.overlap[1]))
                end_w = start_w + w
                # Each tile writes its own region; tiles are split into two
                # planes by (i + j) parity and overlaps are merged below with
                # a "non-zero wins" rule.
                if (i + j) % 2 == 0:
                    result_1[start_h:end_h, start_w:end_w] = im
                else:
                    result_2[start_h:end_h, start_w:end_w] = im
        result = np.where(result_2 != 0, result_2, result_1)
        result = result[:raw_size[0], :raw_size[1]]
        Image.fromarray(result).save(save_path, "PNG")
        return result
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from typing import List, Tuple
from eiseg.plugin.remotesensing.raster import Raster
class RSGrids:
    """Grid (tile) manager for remote-sensing raster data in EISeg.

    Delegates all pixel access to a :class:`Raster` instance and only keeps
    the per-tile mask state on top of it.
    """

    def __init__(self, raset: Raster) -> None:
        """Store the raster wrapper and reset all per-image state.

        Arguments:
            raset (Raster): raster wrapper that performs the windowed reads
                and mask saving.
        """
        super(RSGrids, self).__init__()
        self.raster = raset
        self.clear()

    def clear(self) -> None:
        """Reset all per-image state."""
        self.mask_grids = []  # mask (label) tiles
        self.json_labels = []  # saved labels
        self.grid_count = None  # (row count, col count)
        self.curr_idx = None  # (current row, current col)

    def createGrids(self) -> List[int]:
        """Compute the tile layout and allocate one zero mask per tile."""
        extent = (self.raster.geoinfo.ysize, self.raster.geoinfo.xsize)
        counts = np.ceil(
            (extent + self.raster.overlap) / self.raster.grid_size)
        counts = counts.astype("uint16")
        self.grid_count = counts
        rows, cols = counts[0], counts[1]
        self.mask_grids = [
            [np.zeros(self.raster.grid_size) for _ in range(cols)]
            for _ in range(rows)
        ]
        return list(counts)

    def getGrid(self, row: int, col: int) -> Tuple[np.ndarray]:
        """Read the image tile at (row, col) and pair it with its mask."""
        img, _ = self.raster.getGrid(row, col)
        mask = self.mask_grids[row][col]
        self.curr_idx = (row, col)
        return img, mask

    def splicingList(self, save_path: str) -> np.ndarray:
        """Stitch all mask tiles into one mask and write it to ``save_path``."""
        return self.raster.saveMaskbyGrids(self.mask_grids, save_path,
                                           self.raster.geoinfo)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .imgtools import *
from .shape import *
from .raster import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
from skimage import exposure
# 2% linear stretch
def two_percentLinear(image: np.ndarray, max_out: int=255,
                      min_out: int=0) -> np.ndarray:
    """Contrast-stretch a 3-band image between its 2% and 98% percentiles.

    Each band is clipped to its own [2%, 98%] percentile range and rescaled
    linearly to [min_out, max_out].

    Args:
        image (np.ndarray): HWC image with 3 bands.
        max_out (int): upper bound of the output range. Defaults to 255.
        min_out (int): lower bound of the output range. Defaults to 0.

    Returns:
        np.ndarray: uint8 image with the same band order as the input.
    """

    def __gray_process(gray, maxout=max_out, minout=min_out):
        high_value = np.percentile(gray, 98)  # gray level at 98% of the histogram
        low_value = np.percentile(gray, 2)
        if high_value == low_value:
            # Flat band: the original formula would divide by zero (NaNs);
            # just clip into the output range instead.
            return np.clip(gray.astype("float64"), minout, maxout)
        truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
        # Fix: offset by `minout` so min_out != 0 is honoured
        # (no-op for the default min_out=0).
        processed_gray = ((truncated_gray - low_value) /
                          (high_value - low_value)) * (maxout - minout) + minout
        return processed_gray

    # Per-band processing with plain numpy indexing/stacking (equivalent to
    # the previous cv2.split/cv2.merge, channel order preserved).
    b, g, r = image[:, :, 0], image[:, :, 1], image[:, :, 2]
    result = np.stack(
        (__gray_process(b), __gray_process(g), __gray_process(r)), axis=-1)
    return np.uint8(result)
# Simple image normalization
def sample_norm(image: np.ndarray) -> np.ndarray:
    """Histogram-equalize an image and rescale it to uint8 [0, 255].

    Multi-band (HWC) inputs are equalized band by band, each band being
    renormalized by its own maximum; 2-D inputs are equalized as a whole.
    """
    if image.ndim == 3:
        def _equalize(band):
            eq = exposure.equalize_hist(band)
            return eq / float(np.max(eq))

        bands = [_equalize(image[:, :, i]) for i in range(image.shape[-1])]
        equalized = np.stack(bands, axis=2)
    else:  # 2-D grayscale image
        equalized = exposure.equalize_hist(image)
    return np.uint8(equalized * 255)
# Compute a thumbnail
def get_thumbnail(image: np.ndarray, range: int=2000,
                  max_size: int=1000) -> np.ndarray:
    """Downscale large images so their longest side becomes ``max_size``.

    Images whose height and width are both below ``range`` are returned
    unchanged. (NOTE: the parameter name ``range`` shadows the builtin but
    is kept for caller compatibility.)
    """
    height, width = image.shape[:2]
    if height < range and width < range:
        return image
    if height >= width:
        target = (int(max_size / height * width), max_size)
    else:
        target = (max_size, int(max_size / width * height))
    return cv2.resize(image, target)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import numpy as np
import cv2
import math
from typing import List, Dict, Tuple, Union
from collections import defaultdict
from easydict import EasyDict as edict
from .imgtools import sample_norm, two_percentLinear, get_thumbnail
def check_rasterio() -> bool:
    """Return True if the optional ``rasterio`` package can be imported."""
    try:
        import rasterio  # noqa: F401
        return True
    except ImportError:
        # Only a missing/broken install means "unavailable"; the previous
        # bare `except` would also have hidden unrelated errors.
        return False
# Import rasterio (and Window) only when available; IMPORT_STATE records
# whether the GTiff-backed features below can be used.
IMPORT_STATE = False
if check_rasterio():
    import rasterio
    from rasterio.windows import Window
    IMPORT_STATE = True
class Raster:
    def __init__(self,
                 tif_path: str,
                 show_band: Union[List[int], Tuple[int]]=[1, 1, 1],
                 open_grid: bool=False,
                 grid_size: Union[List[int], Tuple[int]]=[512, 512],
                 overlap: Union[List[int], Tuple[int]]=[24, 24]) -> None:
        """ Class used in EISeg for handling remote-sensing raster data.

        Args:
            tif_path (str): path of the GTiff data.
            show_band (Union[List[int], Tuple[int]], optional): bands used for the RGB composite display. Defaults to [1, 1, 1].
            open_grid (bool, optional): whether the grid-slicing feature is enabled. Defaults to False.
            grid_size (Union[List[int], Tuple[int]], optional): slice size. Defaults to [512, 512].
            overlap (Union[List[int], Tuple[int]], optional): size of the overlap region. Defaults to [24, 24].

        NOTE(review): the list defaults are mutable default arguments; they
        are only read here, but tuples would be safer.
        """
        super(Raster, self).__init__()
        if IMPORT_STATE is False:
            # NOTE(review): raising a plain string is itself a TypeError in
            # Python 3 — this should be `raise ImportError("Can't import rasterio!")`.
            raise ("Can't import rasterio!")
        if osp.exists(tif_path):
            self.src_data = rasterio.open(tif_path)
            self.geoinfo = self.__getRasterInfo()
            self.show_band = list(show_band)
            self.grid_size = np.array(grid_size)
            self.overlap = np.array(overlap)
            self.open_grid = open_grid
        else:
            # NOTE(review): same string-raise issue as above.
            raise ("{0} not exists!".format(tif_path))
        # Longest-side threshold below which the raster is displayed whole.
        self.thumbnail_min = 2000

    def __del__(self) -> None:
        # Close the rasterio dataset when the wrapper is garbage-collected.
        self.src_data.close()

    def __getRasterInfo(self) -> Dict:
        """Collect raster metadata (size, dtype, transform, CRS) into an EasyDict."""
        meta = self.src_data.meta
        geoinfo = edict()
        geoinfo.count = meta["count"]
        geoinfo.dtype = meta["dtype"]
        geoinfo.xsize = meta["width"]
        geoinfo.ysize = meta["height"]
        geoinfo.geotf = meta["transform"]
        geoinfo.crs = meta["crs"]
        if geoinfo.crs is not None:
            geoinfo.crs_wkt = geoinfo.crs.wkt
        else:
            geoinfo.crs_wkt = None
        return geoinfo

    def checkOpenGrid(self, thumbnail_min: Union[int, None]) -> bool:
        """Decide (and record) whether grid slicing is needed for this raster."""
        if isinstance(thumbnail_min, int):
            self.thumbnail_min = thumbnail_min
        if max(self.geoinfo.xsize, self.geoinfo.ysize) <= self.thumbnail_min:
            self.open_grid = False
        else:
            self.open_grid = True
        return self.open_grid

    def setBand(self, bands: Union[List[int], Tuple[int]]) -> None:
        """Set the bands used for RGB composite display."""
        self.show_band = list(bands)

    # def __analysis_proj4(self) -> str:
    #     proj4 = self.geoinfo.crs.wkt  # TODO: parse to proj4
    #     ap_dict = defaultdict(str)
    #     dinf = proj4.split("+")
    #     for df in dinf:
    #         kv = df.strip().split("=")
    #         if len(kv) == 2:
    #             k, v = kv
    #             ap_dict[k] = v
    #     return str("● 投影:{0}\n● 基准:{1}\n● 单位:{2}".format(
    #         ap_dict["proj"], ap_dict["datum"], ap_dict["units"])
    #     )

    def showGeoInfo(self) -> str:
        """Return displayable metadata: (band count, dtype, width, height, CRS code).

        NOTE(review): the annotation says ``str`` but a tuple of strings is
        actually returned.
        """
        # return str("● 波段数:{0}\n● 数据类型:{1}\n● 行数:{2}\n● 列数:{3}\n{4}".format(
        #     self.geoinfo.count, self.geoinfo.dtype, self.geoinfo.xsize,
        #     self.geoinfo.ysize, self.__analysis_proj4())
        # )
        if self.geoinfo.crs is not None:
            crs = str(self.geoinfo.crs.to_string().split(":")[-1])
        else:
            crs = "None"
        return (str(self.geoinfo.count), str(self.geoinfo.dtype),
                str(self.geoinfo.xsize), str(self.geoinfo.ysize), crs)

    def getArray(self) -> Tuple[np.ndarray]:
        """Read the display bands (full image, or a thumbnail in grid mode)."""
        rgb = []
        if not self.open_grid:
            for b in self.show_band:
                rgb.append(self.src_data.read(b))
            geotf = self.geoinfo.geotf
        else:
            # Grid mode: only a thumbnail is read for display; no geo
            # transform applies to the downscaled image.
            for b in self.show_band:
                rgb.append(
                    get_thumbnail(self.src_data.read(b), self.thumbnail_min))
            geotf = None
        ima = np.stack(rgb, axis=2)  # cv2.merge(rgb)
        if self.geoinfo["dtype"] != "uint8":
            ima = sample_norm(ima)
        return two_percentLinear(ima), geotf

    def getGrid(self, row: int, col: int) -> Tuple[np.ndarray]:
        """Windowed read of the tile at (row, col); whole image when grids are off."""
        if self.open_grid is False:
            return self.getArray()
        grid_idx = np.array([row, col])
        # Tiles step by (grid_size - overlap) so neighbours share `overlap` pixels.
        ul = grid_idx * (self.grid_size - self.overlap)
        lr = ul + self.grid_size
        window = Window(ul[1], ul[0], (lr[1] - ul[1]), (lr[0] - ul[0]))
        rgb = []
        for b in self.show_band:
            rgb.append(self.src_data.read(b, window=window))
        win_tf = self.src_data.window_transform(window)
        ima = cv2.merge([np.uint16(c) for c in rgb])
        # NOTE(review): getArray() normalizes for any non-uint8 dtype, while
        # this path only handles "uint32" — confirm the asymmetry is intended.
        if self.geoinfo["dtype"] == "uint32":
            ima = sample_norm(ima)
        return two_percentLinear(ima), win_tf

    def saveMask(self,
                 img: np.array,
                 save_path: str,
                 geoinfo: Union[Dict, None]=None,
                 count: int=1) -> None:
        """Write ``img`` as a georeferenced GTiff (int16, nodata=0)."""
        if geoinfo is None:
            geoinfo = self.geoinfo
        new_meta = self.src_data.meta.copy()
        new_meta.update({
            "driver": "GTiff",
            "width": geoinfo.xsize,
            "height": geoinfo.ysize,
            "count": count,
            "dtype": geoinfo.dtype,
            "crs": geoinfo.crs,
            "transform": geoinfo.geotf[:6],
            "nodata": 0
        })
        img = np.nan_to_num(img).astype("int16")
        with rasterio.open(save_path, "w", **new_meta) as tf:
            if count == 1:
                tf.write(img, indexes=1)
            else:
                # Multi-band: write each HWC channel as its own band.
                for i in range(count):
                    tf.write(img[:, :, i], indexes=(i + 1))

    def saveMaskbyGrids(self,
                        img_list: List[List[np.ndarray]],
                        save_path: Union[str, None]=None,
                        geoinfo: Union[Dict, None]=None) -> np.ndarray:
        """Stitch mask tiles into one mask; optionally save it via saveMask()."""
        if geoinfo is None:
            geoinfo = self.geoinfo
        raw_size = (geoinfo.ysize, geoinfo.xsize)
        h, w = self.grid_size
        # NOTE(review): row/col here ignore the overlap used when the tiles
        # were created — confirm the tile counts always agree.
        row = math.ceil(raw_size[0] / h)
        col = math.ceil(raw_size[1] / w)
        result_1 = np.zeros((h * row, w * col), dtype=np.uint8)
        result_2 = result_1.copy()
        for i in range(row):
            for j in range(col):
                ih, iw = img_list[i][j].shape[:2]
                im = np.zeros(self.grid_size)
                im[:ih, :iw] = img_list[i][j]
                start_h = (i * h) if i == 0 else (i * (h - self.overlap[0]))
                end_h = start_h + h
                start_w = (j * w) if j == 0 else (j * (w - self.overlap[1]))
                end_w = start_w + w
                # Tiles are split into two planes by (i + j) parity; overlaps
                # are merged below with a "non-zero wins" rule.
                if (i + j) % 2 == 0:
                    result_1[start_h:end_h, start_w:end_w] = im
                else:
                    result_2[start_h:end_h, start_w:end_w] = im
        result = np.where(result_2 != 0, result_2, result_1)
        result = result[:raw_size[0], :raw_size[1]]
        if save_path is not None:
            self.saveMask(result, save_path, geoinfo)
        return result
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
def check_gdal() -> bool:
    """Return True if GDAL Python bindings (``gdal`` or ``osgeo.gdal``) import."""
    try:
        import gdal  # noqa: F401  # legacy standalone package
    except ImportError:
        # Fall back to the modern osgeo packaging; a bare `except` here
        # previously hid unrelated import-time errors.
        try:
            from osgeo import gdal  # noqa: F401
        except ImportError:
            return False
    return True
# Import the GDAL stack only when available; IMPORT_STATE gates save_shp below.
IMPORT_STATE = False
if check_gdal():
    try:
        import gdal
        import osr
        import ogr
    except:
        # Fall back to the modern osgeo packaging.
        # NOTE(review): if only `gdal` exists standalone (without osr/ogr),
        # this fallback import may itself raise — confirm on target envs.
        from osgeo import gdal, osr, ogr
    IMPORT_STATE = True
# Save a shp file
def save_shp(shp_path: str, tif_path: str, ignore_index: int=0) -> str:
    """Polygonize band 1 of a GTiff into an ESRI Shapefile.

    Features whose "DN" attribute equals ``ignore_index`` (typically the
    background class) are removed from the result.

    Args:
        shp_path (str): output shapefile path (overwritten if it exists).
        tif_path (str): input GTiff path.
        ignore_index (int): class value to drop from the vector result.

    Returns:
        str: a success message.

    Raises:
        ImportError: if the GDAL stack could not be imported.
    """
    # Guard clause replaces the original `if IMPORT_STATE == True:` wrapper.
    if not IMPORT_STATE:
        raise ImportError("can't import gdal, osr, ogr!")
    ds = gdal.Open(tif_path)
    srcband = ds.GetRasterBand(1)
    maskband = srcband.GetMaskBand()
    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
    gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
    ogr.RegisterAll()
    drv = ogr.GetDriverByName("ESRI Shapefile")
    if osp.exists(shp_path):
        os.remove(shp_path)
    dst_ds = drv.CreateDataSource(shp_path)
    prosrs = osr.SpatialReference(wkt=ds.GetProjection())
    dst_layer = dst_ds.CreateLayer(
        "segmentation", geom_type=ogr.wkbPolygon, srs=prosrs)
    dst_fieldname = "DN"
    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
    dst_layer.CreateField(fd)
    gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
    lyr = dst_ds.GetLayer()
    # Drop the background polygons (DN == ignore_index).
    lyr.SetAttributeFilter("DN = '{}'".format(str(ignore_index)))
    for holes in lyr:
        lyr.DeleteFeature(holes.GetFID())
    dst_ds.Destroy()
    ds = None  # release the GDAL dataset handle
    return "Dataset creation successfully!"
# The video propagation and fusion code was heavily based on https://github.com/hkchengrex/MiVOS
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/hkchengrex/MiVOS/blob/main/LICENSE
from .inference_core import InferenceCore
from .video_tools import overlay_davis
# The video propagation and fusion code was heavily based on https://github.com/hkchengrex/MiVOS
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/hkchengrex/MiVOS/blob/main/LICENSE
import os
import paddle.nn.functional as F
import numpy as np
from .load_model import *
from .util.tensor_util import pad_divide_by, images_to_paddle
from .video_tools import load_video, aggregate_wbg
class InferenceCore:
    """
    images - leave them in original dimension (unpadded), but do normalize them.
             Should be CPU tensors of shape B*T*3*H*W

    mem_profile - How extravagant I can use the GPU memory.
                  Usually more memory -> faster speed but I have not drawn the exact relation
                  0 - Use the most memory
                  1 - Intermediate, larger buffer
                  2 - Intermediate, small buffer
                  3 - Use the minimal amount of GPU memory
                  Note that *none* of the above options will affect the accuracy
                  This is a space-time tradeoff, not a space-performance one

    mem_freq - Period at which new memory are put in the bank
               Higher number -> less memory usage
               Unlike the last option, this *is* a space-performance tradeoff
    """

    def __init__(self, mem_profile=2, mem_freq=5):
        self.cursur = 0
        self.mem_freq = mem_freq
        # Buffer sizes per mem_profile (see class docstring); -1 = unbounded.
        if mem_profile == 0:
            self.q_buf_size = 105
            self.i_buf_size = -1
        elif mem_profile == 1:
            self.q_buf_size = 105
            self.i_buf_size = 105
        elif mem_profile == 2:
            self.q_buf_size = 3
            self.i_buf_size = 3
        else:
            self.q_buf_size = 1
            self.i_buf_size = 1
        self.query_buf = {}  # cache for query features
        self.image_buf = {}  # cache for per-frame image tensors
        self.interacted = set()  # frame indices the user has interacted with
        self.certain_mem_k = None  # memory keys of interacted frames
        self.certain_mem_v = None  # memory values of interacted frames
        self.prob = None  # (k+1, T, 1, H, W) per-object probabilities
        self.fuse_net = None
        self.k = 1  # number of annotated objects

    def reset(self):
        """Clear all per-video state (the loaded networks are kept)."""
        self.cursur = 0
        self.query_buf = {}
        self.image_buf = {}
        self.interacted = set()
        self.certain_mem_k = None
        self.certain_mem_v = None
        self.prob = None
        # self.fuse_net = None
        self.k = 1
        self.images = None
        self.masks = None
        self.np_masks = None

    def set_video(self, video_path):
        """Load a video from disk; returns (frames, fps)."""
        self.images, fps = load_video(video_path)
        self.num_frames, self.height, self.width = self.images.shape[:3]
        return self.images, fps

    def get_one_frames(self, idx):
        # Return the raw frame at index `idx`.
        return self.images[idx]

    def check_match(self, param_key, model_key):
        """Verify that weight names match the model's; raise on mismatch."""
        for p, m in zip(param_key, model_key):
            if p != m[0]:
                print(p)
                print(m[0])
                raise Exception("权重和模型不匹配。请确保指定的权重和模型对应")
        return True

    def set_model(self, param_path=None):
        """Load the four exported static networks located next to ``param_path``.

        Loads the memorize / segmentation / attention / fusion networks from
        the directory containing ``param_path``.
        """
        if param_path is None or not os.path.exists(param_path):
            raise Exception(f"权重路径{param_path}不存在。请指定正确的模型路径")
        # param path
        param_path = os.path.abspath(os.path.dirname(param_path))
        # **************memorize**********************
        memory_model_path = os.path.join(param_path,
                                         'static_propagation_memorize.pdmodel')
        memory_param_path = os.path.join(
            param_path, 'static_propagation_memorize.pdiparams')
        self.prop_net_memory = load_model(memory_model_path, memory_param_path)
        # **************segmentation**********************
        segment_model_path = os.path.join(param_path,
                                          "static_propagation_segment.pdmodel")
        segment_param_path = os.path.join(
            param_path, "static_propagation_segment.pdiparams")
        self.prop_net_segm = load_model(segment_model_path, segment_param_path)
        # **************attention**********************
        attn_model_path = os.path.join(param_path,
                                       'static_propagation_attention')
        self.prop_net_attn = jit_load(attn_model_path)
        # **************fusion**********************
        fusion_model_path = os.path.join(param_path, 'static_fusion.pdmodel')
        fusion_param_path = os.path.join(param_path, 'static_fusion.pdiparams')
        self.fuse_net = load_model(fusion_model_path, fusion_param_path)
        return True, "模型设置成功"

    def set_images(self, images):
        """Normalize/pad the frame stack and allocate mask/probability buffers."""
        # True dimensions
        images = images_to_paddle(images)
        t = images.shape[1]
        h, w = images.shape[-2:]
        # Pad each side to multiples of 16
        images = paddle.to_tensor(images, dtype='float32')
        self.images, self.pad = pad_divide_by(images, 16, images.shape[-2:])
        # Padded dimensions
        nh, nw = self.images.shape[-2:]
        # These two store the same information in different formats
        self.masks = paddle.zeros((t, 1, nh, nw), dtype='int64')
        self.np_masks = np.zeros((t, h, w), dtype=np.int64)
        if self.prob is None:
            self.prob = paddle.zeros(
                (self.k + 1, t, 1, nh, nw), dtype='float32')
            self.prob[0] = 1e-7
        else:
            k, t, c, nh, nw = self.prob.shape
            if (self.k + 1) != k:
                # Object count changed: grow the probability tensor with
                # near-zero slices for the new objects.
                add_obj = abs(self.k + 1 - k)
                new_prob = paddle.zeros([add_obj, t, c, nh, nw]) + 1e-7
                self.prob = paddle.concat([self.prob, new_prob], axis=0)
        self.t, self.h, self.w = t, h, w
        self.nh, self.nw = nh, nw
        self.kh = self.nh // 16
        self.kw = self.nw // 16

    def set_objects(self, num_objects):
        # Number of annotated objects (background excluded).
        self.k = num_objects

    def get_image_buffered(self, idx):
        """Return frame ``idx`` from ``self.images``, with simple caching."""
        if idx not in self.image_buf:
            # Flush buffer
            if len(self.image_buf) > self.i_buf_size:
                self.image_buf = {}
            self.image_buf[idx] = self.images[:, idx]
        result = self.image_buf[idx]
        return result
        # return self.images[:, idx]

    def get_query_kv_mask(self, idx, this_k, this_v):
        """Run the segmentation net on frame ``idx`` against memory (k, v)."""
        # Queries' key/value never change, so we can buffer them here
        # NOTE(review): nothing is ever stored into query_buf in this method,
        # so the buffering described above does not actually happen — confirm.
        if idx not in self.query_buf:
            # Flush buffer
            if len(self.query_buf) > self.q_buf_size:
                self.query_buf = {}
        result = calculate_segmentation(
            self.prop_net_segm,
            self.get_image_buffered(idx).numpy(),
            this_k.numpy(), this_v.numpy())
        mask = result[0]
        quary = result[1]
        return mask, quary

    def do_pass(self, key_k, key_v, idx, forward=True, step_cb=None):
        """
        Do a complete pass that includes propagation and fusion
        key_k/key_v - memory feature of the starting frame
        idx - Frame index of the starting frame
        forward - forward/backward propagation
        step_cb - Callback function used for GUI (progress bar) only
        """
        # Pointer in the memory bank
        num_certain_keys = self.certain_mem_k.shape[2]
        m_front = num_certain_keys
        # Determine the required size of the memory bank
        if forward:
            closest_ti = min([ti for ti in self.interacted
                              if ti > idx] + [self.t])
            total_m = (closest_ti - idx - 1
                       ) // self.mem_freq + 1 + num_certain_keys
        else:
            closest_ti = max([ti for ti in self.interacted if ti < idx] + [-1])
            total_m = (idx - closest_ti - 1
                       ) // self.mem_freq + 1 + num_certain_keys
        K, CK, _, H, W = key_k.shape
        _, CV, _, _, _ = key_v.shape
        # Pre-allocate keys/values memory
        keys = paddle.empty((K, CK, total_m, H, W), dtype='float32')
        values = paddle.empty((K, CV, total_m, H, W), dtype='float32')
        # Initial key/value passed in
        keys[:, :, 0:num_certain_keys] = self.certain_mem_k
        values[:, :, 0:num_certain_keys] = self.certain_mem_v
        prev_in_mem = True
        last_ti = idx
        # Note that we never reach closest_ti, just the frame before it
        if forward:
            this_range = range(idx + 1, closest_ti)
            step = +1
            end = closest_ti - 1
        else:
            this_range = range(idx - 1, closest_ti, -1)
            step = -1
            end = closest_ti + 1
        for ti in this_range:
            # Use one extra memory slot when the previous frame was not
            # committed to the bank.
            if prev_in_mem:
                this_k = keys[:, :, :m_front]
                this_v = values[:, :, :m_front]
            else:
                this_k = keys[:, :, :m_front + 1]
                this_v = values[:, :, :m_front + 1]
            out_mask, quary_key = self.get_query_kv_mask(ti, this_k, this_v)
            out_mask = aggregate_wbg(paddle.to_tensor(out_mask), keep_bg=True)
            if ti != end:
                keys[:, :, m_front:m_front +
                     1], values[:, :, m_front:m_front + 1] = calculate_memorize(
                         self.prop_net_memory,
                         self.get_image_buffered(ti).numpy(),
                         out_mask[1:].numpy())
                if abs(ti - last_ti) >= self.mem_freq:
                    # Memorize the frame
                    m_front += 1
                    last_ti = ti
                    prev_in_mem = True
                else:
                    prev_in_mem = False
            # In-place fusion, maximizes the use of queried buffer
            # esp. for long sequence where the buffer will be flushed
            if (closest_ti != self.t) and (closest_ti != -1):
                self.prob[:, ti] = self.fuse_one_frame(
                    closest_ti, idx, ti, self.prob[:, ti], out_mask, key_k,
                    quary_key)
            else:
                self.prob[:, ti] = out_mask
            # Callback function for the GUI
            if step_cb is not None:
                step_cb()
        return closest_ti

    def fuse_one_frame(self, tc, tr, ti, prev_mask, curr_mask, mk16, qk16):
        """Fuse two propagated masks at frame ``ti`` (``ti`` lies between tc and tr)."""
        assert (tc < ti < tr or tr < ti < tc)
        prob = paddle.zeros((self.k, 1, self.nh, self.nw), dtype='float32')
        # Compute linear coefficients
        nc = abs(tc - ti) / abs(tc - tr)
        nr = abs(tr - ti) / abs(tc - tr)
        dist = paddle.to_tensor([nc, nr], dtype='float32').unsqueeze(0)
        for k in range(1, self.k + 1):
            attn_map = self.prop_net_attn(mk16[k - 1:k], qk16,
                                          self.pos_mask_diff[k:k + 1],
                                          self.neg_mask_diff[k:k + 1])
            w = calculate_fusion(self.fuse_net,
                                 self.get_image_buffered(ti).numpy(),
                                 prev_mask[k:k + 1].numpy(),
                                 curr_mask[k:k + 1].numpy(),
                                 attn_map.numpy(), dist.numpy())
            w = paddle.to_tensor(w)
            w = F.sigmoid(w)
            prob[k - 1:k] = w
        return aggregate_wbg(prob, keep_bg=True)

    def interact(self, mask, idx, total_cb=None, step_cb=None):
        """
        Interact -> Propagate -> Fuse

        mask - One-hot mask of the interacted frame, background included
        idx - Frame index of the interacted frame
        total_cb, step_cb - Callback functions for the GUI

        Return: all mask results in np format for DAVIS evaluation
        """
        self.interacted.add(idx)
        mask, _ = pad_divide_by(mask, 16, mask.shape[-2:])
        # print('self.k is %d' % self.k)
        self.mask_diff = mask - self.prob[:, idx]
        self.pos_mask_diff = self.mask_diff.clip(0, 1)
        self.neg_mask_diff = (-self.mask_diff).clip(0, 1)
        self.prob[:, idx] = mask
        key_k, key_v = calculate_memorize(self.prop_net_memory,
                                          self.get_image_buffered(idx).numpy(),
                                          mask[1:].numpy())
        key_k = paddle.to_tensor(key_k).astype("float32")
        key_v = paddle.to_tensor(key_v).astype('float32')
        if self.certain_mem_k is None:
            self.certain_mem_k = key_k
            self.certain_mem_v = key_v
        else:
            K, CK, _, H, W = self.certain_mem_k.shape
            CV = self.certain_mem_v.shape[1]
            # NOTE(review): `_` here is the time dimension captured by the
            # shape unpack above and reused as a size — confirm the
            # zero-padding shapes for newly added objects are intended.
            self.certain_mem_k = paddle.concat(
                [self.certain_mem_k, paddle.zeros((self.k - K, CK, _, H, W))],
                0)
            self.certain_mem_v = paddle.concat(
                [self.certain_mem_v, paddle.zeros((self.k - K, CV, _, H, W))],
                0)
            self.certain_mem_k = paddle.concat([self.certain_mem_k, key_k], 2)
            self.certain_mem_v = paddle.concat([self.certain_mem_v, key_v], 2)
            # self.certain_mem_k = key_k
            # self.certain_mem_v = key_v
        if total_cb is not None:
            # Finds the total num. frames to process
            front_limit = min([ti for ti in self.interacted
                               if ti > idx] + [self.t])
            back_limit = max([ti for ti in self.interacted if ti < idx] + [-1])
            total_num = front_limit - back_limit - 2  # -1 for shift, -1 for center frame
            if total_num > 0:
                total_cb(total_num)
        with paddle.no_grad():
            self.do_pass(key_k, key_v, idx, True, step_cb=step_cb)
            self.do_pass(key_k, key_v, idx, False, step_cb=step_cb)
        # This is a more memory-efficient argmax
        for ti in range(self.t):
            self.masks[ti] = paddle.argmax(self.prob[:, ti], axis=0)
        out_masks = self.masks
        # Trim paddings
        if self.pad[2] + self.pad[3] > 0:
            out_masks = out_masks[:, :, self.pad[2]:-self.pad[3], :]
        if self.pad[0] + self.pad[1] > 0:
            out_masks = out_masks[:, :, :, self.pad[0]:-self.pad[1]]
        self.np_masks = (out_masks.detach().numpy()[:, 0]).astype(np.int64)
        return self.np_masks

    def update_mask_only(self, prob_mask, idx):
        """
        Interaction only, no propagation/fusion
        prob_mask - mask of the interacted frame, background included
        idx - Frame index of the interacted frame

        Return: all mask results in np format for DAVIS evaluation
        """
        mask = paddle.argmax(prob_mask, 0)
        self.masks[idx:idx + 1] = mask
        # Mask - 1 * H * W
        if self.pad[2] + self.pad[3] > 0:
            mask = mask[:, self.pad[2]:-self.pad[3], :]
        if self.pad[0] + self.pad[1] > 0:
            mask = mask[:, :, self.pad[0]:-self.pad[1]]
        mask = (mask.detach().numpy()[0]).astype(np.int64)
        self.np_masks[idx:idx + 1] = mask
        return self.np_masks
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.inference as paddle_infer
def load_model(model_path, param_path, use_gpu=None):
    """Build a Paddle-Inference predictor for an exported static model.

    Args:
        model_path: path to the ``.pdmodel`` file.
        param_path: path to the ``.pdiparams`` file.
        use_gpu: force GPU (True) or CPU (False); ``None`` auto-detects
            from the Paddle build.

    Returns:
        A ``paddle.inference`` predictor.
    """
    config = paddle_infer.Config(model_path, param_path)
    if use_gpu is None:
        # TODO: is_compiled_with_cuda() can report False even when a GPU is usable.
        use_gpu = bool(paddle.device.is_compiled_with_cuda())
    if use_gpu:
        config.enable_use_gpu(500, 0)
        for pass_name in ("conv_elementwise_add_act_fuse_pass",
                          "conv_elementwise_add2_act_fuse_pass",
                          "conv_elementwise_add_fuse_pass"):
            config.delete_pass(pass_name)
        config.switch_ir_optim()
        config.enable_memory_optim()
    else:
        config.enable_mkldnn()
        # TODO: fluid is being deprecated; revisit how this should be detected.
        config.switch_ir_optim(True)
        config.set_cpu_math_library_num_threads(10)
    return paddle_infer.create_predictor(config)
def jit_load(path):
    """Load a ``paddle.jit``-saved model and switch it to eval mode."""
    loaded = paddle.jit.load(path)
    loaded.eval()
    return loaded
def calculate_memorize(model, frame, masks):
    """Run the memorize network on (frame, masks); return its two outputs."""
    in_names = model.get_input_names()
    # Feed the inputs in declaration order: frame first, then masks.
    for name, value in zip(in_names, (frame, masks)):
        model.get_input_handle(name).copy_from_cpu(value)
    model.run()
    out_names = model.get_output_names()
    first = model.get_output_handle(out_names[0]).copy_to_cpu()
    second = model.get_output_handle(out_names[1]).copy_to_cpu()
    return first, second
def calculate_segmentation(model, frame, keys, values):
    """Run the segmentation network on (frame, keys, values); return its two outputs."""
    in_names = model.get_input_names()
    # Feed the inputs in declaration order: frame, keys, values.
    for name, value in zip(in_names, (frame, keys, values)):
        model.get_input_handle(name).copy_from_cpu(value)
    model.run()
    out_names = model.get_output_names()
    mask = model.get_output_handle(out_names[0]).copy_to_cpu()
    query = model.get_output_handle(out_names[1]).copy_to_cpu()
    return mask, query
def calculate_attention(model, mk16, qk16, pos_mask, neg_mask):
    """Run the attention network on its four inputs; return its single output."""
    in_names = model.get_input_names()
    # Feed the inputs in declaration order: mk16, qk16, pos_mask, neg_mask.
    for name, value in zip(in_names, (mk16, qk16, pos_mask, neg_mask)):
        model.get_input_handle(name).copy_from_cpu(value)
    model.run()
    out_name = model.get_output_names()[0]
    return model.get_output_handle(out_name).copy_to_cpu()
def calculate_fusion(model, im, seg1, seg2, attn, time):
    """Run the fusion predictor that blends two propagated segmentations.

    The five arrays map onto the predictor's first five inputs in order;
    the predictor's single (first) output is returned.

    Args:
        model: a Paddle Inference predictor exposing the handle API.
        im: image tensor input.
        seg1: first propagated segmentation.
        seg2: second propagated segmentation.
        attn: attention map input.
        time: temporal-distance input.

    Returns:
        The predictor's first output.
    """
    input_names = model.get_input_names()
    inputs = (im, seg1, seg2, attn, time)
    for name, value in zip(input_names[:5], inputs):
        model.get_input_handle(name).copy_from_cpu(value)
    model.run()
    out_name = model.get_output_names()[0]
    return model.get_output_handle(out_name).copy_to_cpu()
# The video propagation and fusion code was heavily based on https://github.com/hkchengrex/MiVOS
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/hkchengrex/MiVOS/blob/main/LICENSE
from paddle.vision import transforms
# Mean pixel value in RGB order -- presumably an ImageNet-style fill color;
# NOTE(review): appears unused in this chunk, verify callers before removing.
im_mean = (124, 116, 104)
# Standard ImageNet channel normalization, applied to tensors scaled to [0, 1].
# NOTE(review): this name is re-assigned identically further down in the file
# (the file looks like a concatenation of modules) -- confirm which binding wins.
im_normalization = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225], )
# Inverse transform of `im_normalization`: maps a normalized tensor back to
# the original [0, 1] RGB range (x * std + mean, expressed as a Normalize).
inv_im_trans = transforms.Normalize(
    mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
    std=[1 / 0.229, 1 / 0.224, 1 / 0.225], )
# The video propagation and fusion code was heavily based on https://github.com/hkchengrex/MiVOS
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/hkchengrex/MiVOS/blob/main/LICENSE
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.vision import transforms
# Standard ImageNet channel normalization for [0, 1] RGB tensors.
# NOTE(review): duplicates the identical definitions earlier in the file;
# kept byte-identical here, but the duplication should be cleaned up.
im_normalization = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# Inverse of `im_normalization`: recovers the original [0, 1] RGB values.
inv_im_trans = transforms.Normalize(
    mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
    std=[1 / 0.229, 1 / 0.224, 1 / 0.225])
def images_to_paddle(frames):
    """Convert a stacked uint8 frame array to a normalized paddle tensor.

    Args:
        frames: (T, H, W, C) array of frames -- assumed uint8 RGB in
            [0, 255]; TODO confirm against callers.

    Returns:
        Float32 tensor of shape (1, T, C, H, W), scaled to [0, 1] and
        ImageNet-normalized frame by frame.
    """
    tensor = paddle.to_tensor(frames.transpose([0, 3, 1, 2]))
    tensor = tensor.astype('float32').unsqueeze(0) / 255
    num_frames = tensor.shape[1]
    # Normalize has no batch dimension, so apply it one frame at a time.
    for ti in range(num_frames):
        tensor[0, ti] = im_normalization(tensor[0, ti])
    return tensor
def compute_tensor_iu(seg, gt):
    """Return (intersection, union) element counts of two boolean tensors.

    Args:
        seg: boolean prediction tensor.
        gt: boolean ground-truth tensor of the same shape.

    Returns:
        Tuple of float scalars: intersection count, union count.
    """
    overlap = seg & gt
    combined = seg | gt
    return overlap.astype('float32').sum(), combined.astype('float32').sum()
def compute_np_iu(seg, gt):
    """Return (intersection, union) element counts of two boolean numpy arrays.

    Args:
        seg: boolean prediction array.
        gt: boolean ground-truth array of the same shape.

    Returns:
        Tuple of float32 scalars: intersection count, union count.
    """
    overlap = (seg & gt).astype(np.float32)
    combined = (seg | gt).astype(np.float32)
    return overlap.sum(), combined.sum()
def compute_tensor_iou(seg, gt):
    """Smoothed IoU of two boolean tensors (eps = 1e-6 guards against 0/0).

    Args:
        seg: boolean prediction tensor.
        gt: boolean ground-truth tensor of the same shape.

    Returns:
        Scalar IoU in (0, 1].
    """
    # Inlined intersection/union computation (was a call to compute_tensor_iu).
    overlap = (seg & gt).astype('float32').sum()
    combined = (seg | gt).astype('float32').sum()
    return (overlap + 1e-6) / (combined + 1e-6)
def compute_np_iou(seg, gt):
    """Smoothed IoU of two boolean numpy arrays (eps = 1e-6 avoids 0/0).

    Args:
        seg: boolean prediction array.
        gt: boolean ground-truth array of the same shape.

    Returns:
        Scalar IoU in (0, 1].
    """
    # Inlined intersection/union computation (was a call to compute_np_iu).
    overlap = (seg & gt).astype(np.float32).sum()
    combined = (seg | gt).astype(np.float32).sum()
    return (overlap + 1e-6) / (combined + 1e-6)
def compute_multi_class_iou(seg, gt):
    """Mean smoothed IoU over foreground classes of a multi-class prediction.

    Args:
        seg: (k+1)*h*w class scores including background at channel 0 --
            TODO confirm channel layout against callers.
        gt: k*1*h*w per-class ground-truth masks.

    Returns:
        Scalar mean IoU, smoothed with eps = 1e-6.
    """
    num_classes = gt.shape[0]
    pred_idx = paddle.argmax(seg, axis=0)
    total = 0
    for ki in range(num_classes):
        # gt channel ki corresponds to label ki+1 in seg (label 0 is background).
        total += compute_tensor_iou(pred_idx == (ki + 1), gt[ki, 0] > 0.5)
    return (total + 1e-6) / (num_classes + 1e-6)
def compute_multi_class_iou_idx(seg, gt):
    """Mean smoothed IoU between a label map and per-class binary masks.

    Args:
        seg: h*w predicted label map where 0 is background.
        gt: k*h*w per-class ground-truth masks.

    Returns:
        Scalar mean IoU, smoothed with eps = 1e-6.
    """
    num_classes = gt.shape[0]
    total = 0
    for ki in range(num_classes):
        # gt channel ki corresponds to label ki+1 (label 0 is background).
        pred = seg == (ki + 1)
        ref = gt[ki] > 0.5
        # Inlined smoothed IoU (was a call to compute_np_iou).
        inter = (pred & ref).astype(np.float32).sum()
        union = (pred | ref).astype(np.float32).sum()
        total += (inter + 1e-6) / (union + 1e-6)
    return (total + 1e-6) / (num_classes + 1e-6)
def compute_multi_class_iou_both_idx(seg, gt):
    """Mean smoothed IoU between two label maps over gt's foreground labels.

    Args:
        seg: h*w predicted label map.
        gt: h*w ground-truth label map; its max value defines the class count.

    Returns:
        Scalar mean IoU, smoothed with eps = 1e-6.
    """
    num_classes = gt.max()
    total = 0
    for label in range(1, num_classes + 1):
        pred = seg == label
        ref = gt == label
        # Inlined smoothed IoU (was a call to compute_np_iou).
        inter = (pred & ref).astype(np.float32).sum()
        union = (pred | ref).astype(np.float32).sum()
        total += (inter + 1e-6) / (union + 1e-6)
    return (total + 1e-6) / (num_classes + 1e-6)
# STM
def pad_divide_by(in_img, d, in_size=None):
    """Center-pad a tensor so its spatial dims become multiples of ``d``.

    Args:
        in_img: 4-D (N, C, H, W) or 5-D (N, B, C, H, W) tensor.
        d: divisor the padded height and width must satisfy.
        in_size: optional (h, w) pair overriding the size read from ``in_img``.

    Returns:
        (padded, pad_array) where pad_array is (left, right, top, bottom)
        in the order expected by ``unpad``.
    """
    if in_size is None:
        h, w = in_img.shape[-2:]
    else:
        h, w = in_size
    # Extra pixels needed to reach the next multiple of d (0 when aligned).
    extra_h = (d - h % d) % d
    extra_w = (d - w % d) % d
    lh = int(extra_h / 2)
    lw = int(extra_w / 2)
    pad_array = (int(lw), int(extra_w - lw), int(lh), int(extra_h - lh))
    if len(in_img.shape) == 5:
        # F.pad only handles NCHW, so fold the two leading dims together.
        # NOTE(review): the result is unsqueezed to (1, N*B, C, H', W') rather
        # than restored to (N, B, ...) -- callers appear to rely on this.
        N, B, C, H, W = in_img.shape
        in_img = in_img.reshape([-1, C, H, W])
        out = F.pad(in_img, pad_array, data_format='NCHW').unsqueeze(0)
    else:
        out = F.pad(in_img, pad_array, data_format='NCHW')
    return out, pad_array
def unpad(img, pad):
    """Crop a 4-D (N, C, H, W) tensor back to its pre-padding size.

    Args:
        img: padded tensor.
        pad: (left, right, top, bottom) widths as returned by pad_divide_by.

    Returns:
        The cropped tensor (a view/slice; unchanged when all pads are zero).
    """
    left, right, top, bottom = pad
    if top + bottom > 0:
        img = img[:, :, top:-bottom, :]
    if left + right > 0:
        img = img[:, :, :, left:-right]
    return img
def unpad_3dim(img, pad):
    """Crop a 3-D (C, H, W) tensor back to its pre-padding size.

    Args:
        img: padded tensor.
        pad: (left, right, top, bottom) widths as returned by pad_divide_by.

    Returns:
        The cropped tensor (unchanged when all pads are zero).
    """
    left, right, top, bottom = pad
    if top + bottom > 0:
        img = img[:, top:-bottom, :]
    if left + right > 0:
        img = img[:, :, left:-right]
    return img
# The video propagation and fusion code was heavily based on https://github.com/hkchengrex/MiVOS
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/hkchengrex/MiVOS/blob/main/LICENSE
import glob
import os
import cv2
import numpy as np
import paddle
import paddle.nn.functional as F
from PIL import Image
from eiseg.util.vis import get_palette
def load_video(path, min_side=480):
    """Load every frame of a video as RGB, optionally resizing the short side.

    Args:
        path: video file path readable by OpenCV.
        min_side: target length of the shorter side; a falsy value keeps the
            original resolution.

    Returns:
        (frames, fps): frames is a (T, H, W, 3) uint8 RGB array, fps the
        container frame rate reported by OpenCV.

    Raises:
        ValueError: from np.stack when the video yields no frames.
    """
    frame_list = []
    cap = cv2.VideoCapture(path)
    try:
        while cap.isOpened():
            # Fix: the original discarded the success flag from read();
            # check it so a failed grab cannot slip a stale frame through.
            ret, frame = cap.read()
            if not ret or frame is None:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if min_side:
                h, w = frame.shape[:2]
                short = min(w, h)
                new_w = w * min_side // short
                new_h = h * min_side // short
                frame = cv2.resize(
                    frame, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
            frame_list.append(frame)
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Fix: the capture handle was never released (file/device leak).
        cap.release()
    frames = np.stack(frame_list, axis=0)
    return frames, fps
def load_masks(path, min_side=None):
    """Load all PNG masks under ``path`` into one stacked uint8 array.

    Args:
        path: directory containing the *.png mask files (sorted by name).
        min_side: if set, resize so the shorter side equals this value
            (nearest-neighbor, to keep label values intact).

    Returns:
        (T, H, W[, C]) uint8 array. If the first mask's max value is 255,
        all masks are binarized to {0, 1} with a >128 threshold.
    """
    fnames = sorted(glob.glob(os.path.join(path, '*.png')))
    # Heuristic: a max of 255 in the first mask signals a binary mask set.
    first_frame = np.array(Image.open(fnames[0]))
    binary_mask = first_frame.max() == 255
    frame_list = []
    for fname in fnames:
        image = Image.open(fname)
        if min_side:
            w, h = image.size
            short = min(w, h)
            new_size = (w * min_side // short, h * min_side // short)
            image = image.resize(new_size, Image.NEAREST)
        frame_list.append(np.array(image, dtype=np.uint8))
    frames = np.stack(frame_list, axis=0)
    if binary_mask:
        frames = (frames > 128).astype(np.uint8)
    return frames
def overlay_davis(image, mask, alpha=0.5, palette=None):
    """Blend a segmentation mask's colors over an RGB image (DAVIS-style).

    Args:
        image: h*w*3 uint8 RGB frame.
        mask: h*w integer label map, or None to return a plain copy.
        alpha: weight of the palette color inside labeled regions.
        palette: color table indexable by label; auto-generated via
            get_palette when falsy.

    Returns:
        h*w*3 uint8 blended image.
    """
    result = image.copy()
    if mask is None:
        return result
    if not palette:
        palette = get_palette(np.max(mask) + 1)
    colors = np.array(palette)
    rgb_mask = colors[mask.astype(np.uint8)]
    fg = (mask > 0).astype(np.uint8)[:, :, np.newaxis]
    # NOTE(review): alpha * rgb_mask is added to every pixel, so unmasked
    # pixels stay intact only when palette[0] (background) is black -- verify.
    blended = result * (1 - fg) + (1 - alpha) * fg * result + alpha * rgb_mask
    return blended.astype(np.uint8)
def aggregate_wbg(prob, keep_bg=False, hard=False):
    """Aggregate per-object probabilities with a derived background channel.

    The background probability is taken as the product of (1 - p) over all
    objects, then all channels are converted to logits and re-normalized
    with a softmax (soft aggregation, as in the MiVOS codebase).

    Args:
        prob: k*1*h*w per-object probability tensor.
        keep_bg: include the background channel in the returned softmax.
        hard: sharpen the softmax toward a near-hard assignment.

    Returns:
        (k+1)*1*h*w softmax when keep_bg is True, else k*1*h*w with the
        background channel dropped.
    """
    k, _, h, w = prob.shape  # kept as an implicit 4-D shape check
    bg_prob = paddle.prod(1 - prob, axis=0, keepdim=True)
    stacked = paddle.concat([bg_prob, prob], 0).clip(1e-7, 1 - 1e-7)
    logits = paddle.log(stacked / (1 - stacked))
    if hard:
        # A huge temperature makes the softmax effectively an argmax.
        logits *= 1000
    out = F.softmax(logits, axis=0)
    return out if keep_bg else out[1:]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment