Commit 0d97cc8c authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add new model

parents
Pipeline #316 failed with stages
in 0 seconds
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .rs_grid import RSGrids
from .grid import Grids, checkOpenGrid
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
from PIL import Image
def checkOpenGrid(img, thumbnail_min):
    """Tell whether an image is too large to be handled without tiling.

    Args:
        img: Array-like exposing ``shape``; the first two entries are read
            as height and width.
        thumbnail_min: Threshold for the longer image side.

    Returns:
        bool: ``False`` when the longer side is at most ``thumbnail_min``,
        ``True`` otherwise.
    """
    height, width = img.shape[:2]
    longest_side = max(height, width)
    return not longest_side <= thumbnail_min
class Grids:
    """Cut an image into overlapping tiles and stitch per-tile masks back.

    Tile ``(row, col)`` starts at ``(row, col) * (gridSize - overlap)`` and
    spans ``gridSize`` pixels, so neighbouring tiles share ``overlap`` pixels.
    All sizes are ``(height, width)`` pairs.
    """

    def __init__(self, img, gridSize=(512, 512), overlap=(24, 24)):
        """Store the source image and the tiling geometry.

        Args:
            img: Image array whose first two axes are height and width.
            gridSize: Tile size as ``(height, width)``.
            overlap: Pixels shared by adjacent tiles, ``(height, width)``.
        """
        self.clear()
        self.detimg = img
        self.gridSize = np.array(gridSize)
        self.overlap = np.array(overlap)

    def clear(self):
        """Reset all tiling state to its initial (empty) values."""
        self.detimg = None  # source image (HWC layout)
        self.grid_init = False  # whether createGrids() has been run
        self.mask_grids = []  # one label mask per tile, indexed [row][col]
        self.json_labels = []  # saved labels
        self.grid_count = None  # (row count, col count)
        self.curr_idx = None  # (current row, current col)

    def createGrids(self):
        """Compute the tile layout and allocate an empty mask per tile.

        The count per axis is the smallest ``n`` for which
        ``(n - 1) * (gridSize - overlap) + gridSize`` reaches the image
        size, i.e. it matches the stride used by :meth:`getGrid` so that
        the whole image is covered.

        Returns:
            list: ``[row_count, col_count]``.

        Raises:
            ValueError: If ``overlap`` is not smaller than ``gridSize``.
        """
        img_size = np.array(self.detimg.shape[:2])
        stride = self.gridSize - self.overlap
        if np.any(stride <= 0):
            raise ValueError("overlap must be smaller than gridSize")
        grid_count = np.ceil((img_size - self.overlap) / stride)
        # Images smaller than the overlap still need one tile.
        grid_count = np.maximum(grid_count, 1).astype("uint16")
        self.grid_count = grid_count
        # Build every mask separately; list multiplication would alias them.
        self.mask_grids = [
            [np.zeros(self.gridSize) for _ in range(grid_count[1])]
            for _ in range(grid_count[0])
        ]
        self.grid_init = True
        return list(grid_count)

    def getGrid(self, row, col):
        """Return the image crop and the mask of tile ``(row, col)``.

        The crop is a plain slice of the source image, so tiles on the
        bottom/right border may be smaller than ``gridSize``.
        """
        gridIdx = np.array([row, col])
        ul = gridIdx * (self.gridSize - self.overlap)
        lr = ul + self.gridSize
        img = self.detimg[ul[0]:lr[0], ul[1]:lr[1]]
        mask = self.mask_grids[row][col]
        self.curr_idx = (row, col)
        return img, mask

    def splicingList(self, save_path):
        """Stitch all tile masks into one mask, save it as PNG, return it.

        Tiles are pasted on two canvases in a checkerboard pattern; within
        a canvas only non-zero pixels are written, and where both canvases
        hold a label the odd-parity canvas wins. The result is cropped to
        the source image size.

        Args:
            save_path: Destination path of the PNG file.

        Returns:
            numpy.ndarray: ``uint8`` mask with the source image's height
            and width.

        Raises:
            IndexError: If no tile masks exist (createGrids() not called).
        """
        if not self.mask_grids or not self.mask_grids[0]:
            raise IndexError(
                "no tile masks available; call createGrids() first")
        raw_h, raw_w = self.detimg.shape[:2]
        tile_h, tile_w = (int(v) for v in self.gridSize)
        step_h, step_w = (int(v) for v in self.gridSize - self.overlap)
        rows = len(self.mask_grids)
        cols = len(self.mask_grids[0])
        canvas_shape = (max(raw_h, (rows - 1) * step_h + tile_h),
                        max(raw_w, (cols - 1) * step_w + tile_w))
        even = np.zeros(canvas_shape, dtype=np.uint8)
        odd = np.zeros(canvas_shape, dtype=np.uint8)
        for i in range(rows):
            for j in range(cols):
                mask = self.mask_grids[i][j]
                mask_h, mask_w = mask.shape[:2]
                # Border masks may be smaller than a tile; pad with zeros.
                tile = np.zeros((tile_h, tile_w), dtype=np.uint8)
                tile[:mask_h, :mask_w] = mask
                top, left = i * step_h, j * step_w
                target = even if (i + j) % 2 == 0 else odd
                region = target[top:top + tile_h, left:left + tile_w]
                # Keep labels already present where this tile is empty.
                np.copyto(region, tile, where=tile != 0)
        result = np.where(odd != 0, odd, even)
        result = result[:raw_h, :raw_w]
        Image.fromarray(result).save(save_path, "PNG")
        return result
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from typing import List, Tuple
from eiseg.plugin.remotesensing.raster import Raster
class RSGrids:
    """Tile bookkeeping for a remote-sensing raster in EISeg.

    The geometry (image size, tile size, overlap) and all pixel access are
    delegated to the wrapped raster object; this class only tracks the
    per-tile label masks and the currently selected tile.
    """

    def __init__(self, raset: "Raster") -> None:
        """Wrap an opened raster.

        Args:
            raset: Raster object providing ``geoinfo`` (with ``xsize`` and
                ``ysize``), ``grid_size``, ``overlap``, ``getGrid`` and
                ``saveMaskbyGrids``.
        """
        super().__init__()
        self.raster = raset
        self.clear()

    def clear(self) -> None:
        """Reset the tile masks and the current-tile state."""
        self.mask_grids = []  # one label mask per tile, indexed [row][col]
        self.json_labels = []  # saved labels
        self.grid_count = None  # (row count, col count)
        self.curr_idx = None  # (current row, current col)

    def createGrids(self) -> List[int]:
        """Compute the tile layout and allocate an empty mask per tile.

        The count per axis is the smallest ``n`` for which
        ``(n - 1) * (grid_size - overlap) + grid_size`` reaches the raster
        size, so tiles stepped by ``grid_size - overlap`` cover the raster.

        Returns:
            ``[row_count, col_count]``.

        Raises:
            ValueError: If ``overlap`` is not smaller than ``grid_size``.
        """
        img_size = np.array(
            [self.raster.geoinfo.ysize, self.raster.geoinfo.xsize])
        grid_size = np.asarray(self.raster.grid_size)
        overlap = np.asarray(self.raster.overlap)
        stride = grid_size - overlap
        if np.any(stride <= 0):
            raise ValueError("overlap must be smaller than grid_size")
        grid_count = np.ceil((img_size - overlap) / stride)
        # Rasters smaller than the overlap still need one tile.
        grid_count = np.maximum(grid_count, 1).astype("uint16")
        self.grid_count = grid_count
        self.mask_grids = [[np.zeros(grid_size)
                            for _ in range(grid_count[1])]
                           for _ in range(grid_count[0])]
        return list(grid_count)

    def getGrid(self, row: int, col: int) -> Tuple[np.ndarray, np.ndarray]:
        """Return the raster's image for tile ``(row, col)`` and its mask."""
        img, _ = self.raster.getGrid(row, col)
        mask = self.mask_grids[row][col]
        self.curr_idx = (row, col)
        return img, mask

    def splicingList(self, save_path: str) -> np.ndarray:
        """Stitch all tile masks via the raster and save to ``save_path``.

        Returns:
            The stitched mask returned by ``raster.saveMaskbyGrids``.
        """
        mask = self.raster.saveMaskbyGrids(self.mask_grids, save_path,
                                           self.raster.geoinfo)
        return mask
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .imgtools import *
from .shape import *
from .raster import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
from skimage import exposure
# 2% linear stretch
def two_percentLinear(image: np.ndarray, max_out: int=255,
                      min_out: int=0) -> np.ndarray:
    """Apply a 2%-98% percentile linear stretch to every band.

    Each band is clipped to its own 2nd and 98th percentiles and then
    mapped linearly onto ``[min_out, max_out]``. A band whose two
    percentiles coincide (e.g. a constant band) maps to ``min_out``
    instead of dividing by zero.

    Args:
        image: ``H x W x C`` array (bands on the last axis) or a single
            ``H x W`` band.
        max_out: Upper bound of the output range.
        min_out: Lower bound of the output range.

    Returns:
        ``uint8`` array with the same spatial layout as ``image``; values
        are truncated, not rounded, when cast.
    """

    def _stretch(band: np.ndarray) -> np.ndarray:
        # Stretch one band between its 2nd and 98th percentiles.
        low_value, high_value = np.percentile(band, (2, 98))
        value_range = high_value - low_value
        if value_range == 0:
            return np.full(band.shape, float(min_out))
        clipped = np.clip(band, a_min=low_value, a_max=high_value)
        scaled = (clipped - low_value) / value_range
        return scaled * (max_out - min_out) + min_out

    image = np.asarray(image)
    if image.ndim == 2:
        return np.uint8(_stretch(image))
    stretched = [_stretch(image[:, :, idx]) for idx in range(image.shape[2])]
    return np.uint8(np.stack(stretched, axis=2))
# Simple image normalisation
def sample_norm(image: np.ndarray) -> np.ndarray:
    """Histogram-equalise an image and rescale it to ``uint8``.

    A 3-D input is equalised band by band along its last axis, each band
    additionally divided by its own maximum; any other input is equalised
    as a whole. The result is multiplied by 255 and cast to ``uint8``.
    """
    if len(image.shape) != 3:
        equalized = exposure.equalize_hist(image)
        return np.uint8(equalized * 255)

    bands = []
    for band_idx in range(image.shape[-1]):
        equalized_band = exposure.equalize_hist(image[:, :, band_idx])
        bands.append(equalized_band / float(np.max(equalized_band)))
    return np.uint8(np.stack(bands, axis=2) * 255)
# Compute a thumbnail
def get_thumbnail(image: np.ndarray, range: int=2000,
                  max_size: int=1000) -> np.ndarray:
    """Shrink an image whose height or width reaches ``range``.

    Images with both sides below ``range`` are returned unchanged (the
    same object). Otherwise the longer side is resized to ``max_size``
    and the shorter side is scaled proportionally, but never below one
    pixel so that extreme aspect ratios cannot request a zero-sized
    output.

    Args:
        image: Array whose first two axes are height and width.
        range: Side length from which resizing kicks in. The name shadows
            the builtin but is kept for keyword-argument compatibility.
        max_size: Length of the longer side after resizing.
    """
    h, w = image.shape[:2]
    if h < range and w < range:
        return image
    # cv2.resize takes the target size as (width, height).
    if h >= w:
        target = (max(1, int(max_size / h * w)), max_size)
    else:
        target = (max_size, max(1, int(max_size / w * h)))
    return cv2.resize(image, target)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import numpy as np
import cv2
import math
from typing import List, Dict, Tuple, Union
from collections import defaultdict
from easydict import EasyDict as edict
from .imgtools import sample_norm, two_percentLinear, get_thumbnail
def check_rasterio() -> bool:
    """Return ``True`` when ``rasterio`` can be imported, else ``False``.

    Any ordinary exception raised by the import counts as "not
    available"; ``KeyboardInterrupt`` and ``SystemExit`` are not
    swallowed.
    """
    try:
        import rasterio  # noqa: F401
    except Exception:
        return False
    return True
# Import rasterio only when it is importable, so this module still loads
# without it; IMPORT_STATE records whether the import happened.
if check_rasterio():
    import rasterio
    from rasterio.windows import Window

    IMPORT_STATE = True
else:
    IMPORT_STATE = False
class Raster:
    """Reader/writer for remote-sensing rasters (GTiff) used by EISeg.

    Wraps a dataset opened with rasterio and offers whole-image or
    tile-wise RGB composites for display, plus helpers to write label
    masks back using the source georeferencing.
    """

    def __init__(self,
                 tif_path: str,
                 show_band: Union[List[int], Tuple[int]]=(1, 1, 1),
                 open_grid: bool=False,
                 grid_size: Union[List[int], Tuple[int]]=(512, 512),
                 overlap: Union[List[int], Tuple[int]]=(24, 24)) -> None:
        """Open ``tif_path`` and cache its metadata.

        Args:
            tif_path: Path of the GTiff file.
            show_band: Band indices used for the RGB composite.
                Defaults to (1, 1, 1).
            open_grid: Whether tile (grid) mode is enabled.
            grid_size: Tile size as ``(height, width)``.
            overlap: Overlap between adjacent tiles, ``(height, width)``.

        Raises:
            ImportError: If rasterio could not be imported.
            FileNotFoundError: If ``tif_path`` does not exist.
        """
        super().__init__()
        if IMPORT_STATE is False:
            raise ImportError("Can't import rasterio!")
        if not osp.exists(tif_path):
            raise FileNotFoundError("{0} not exists!".format(tif_path))
        self.src_data = rasterio.open(tif_path)
        self.geoinfo = self.__getRasterInfo()
        self.show_band = list(show_band)
        self.grid_size = np.array(grid_size)
        self.overlap = np.array(overlap)
        self.open_grid = open_grid
        self.thumbnail_min = 2000

    def __del__(self) -> None:
        # src_data is missing when __init__ raised before opening the file.
        src_data = getattr(self, "src_data", None)
        if src_data is not None:
            src_data.close()

    def __getRasterInfo(self) -> Dict:
        """Copy the fields of the dataset's ``meta`` into an EasyDict."""
        meta = self.src_data.meta
        geoinfo = edict()
        geoinfo.count = meta["count"]
        geoinfo.dtype = meta["dtype"]
        geoinfo.xsize = meta["width"]
        geoinfo.ysize = meta["height"]
        geoinfo.geotf = meta["transform"]
        geoinfo.crs = meta["crs"]
        if geoinfo.crs is not None:
            geoinfo.crs_wkt = geoinfo.crs.wkt
        else:
            geoinfo.crs_wkt = None
        return geoinfo

    def checkOpenGrid(self, thumbnail_min: Union[int, None]) -> bool:
        """Decide whether tile mode is needed and remember the decision.

        Args:
            thumbnail_min: New size threshold; ignored unless it is an
                ``int``, in which case it replaces ``self.thumbnail_min``.

        Returns:
            ``True`` when the longer raster side exceeds the threshold.
        """
        if isinstance(thumbnail_min, int):
            self.thumbnail_min = thumbnail_min
        longest_side = max(self.geoinfo.xsize, self.geoinfo.ysize)
        self.open_grid = not longest_side <= self.thumbnail_min
        return self.open_grid

    def setBand(self, bands: Union[List[int], Tuple[int]]) -> None:
        """Select the bands used for the RGB composite."""
        self.show_band = list(bands)

    def showGeoInfo(self) -> Tuple[str, str, str, str, str]:
        """Return (band count, dtype, xsize, ysize, crs) as strings.

        The CRS entry is the part after the last ``:`` of
        ``crs.to_string()``, or ``"None"`` when the raster has no CRS.
        """
        if self.geoinfo.crs is not None:
            crs = str(self.geoinfo.crs.to_string().split(":")[-1])
        else:
            crs = "None"
        return (str(self.geoinfo.count), str(self.geoinfo.dtype),
                str(self.geoinfo.xsize), str(self.geoinfo.ysize), crs)

    def getArray(self) -> Tuple[np.ndarray, object]:
        """Read the display bands of the whole raster.

        In tile mode every band is reduced with ``get_thumbnail`` and no
        transform is returned; otherwise the full bands and the raster's
        transform are returned.

        Returns:
            ``(image, transform)`` where ``image`` is the stretched
            ``H x W x 3`` composite and ``transform`` may be ``None``.
        """
        rgb = []
        if not self.open_grid:
            for b in self.show_band:
                rgb.append(self.src_data.read(b))
            geotf = self.geoinfo.geotf
        else:
            for b in self.show_band:
                rgb.append(
                    get_thumbnail(self.src_data.read(b), self.thumbnail_min))
            geotf = None
        ima = np.stack(rgb, axis=2)
        if self.geoinfo["dtype"] != "uint8":
            ima = sample_norm(ima)
        return two_percentLinear(ima), geotf

    def getGrid(self, row: int, col: int) -> Tuple[np.ndarray, object]:
        """Read the display bands of tile ``(row, col)``.

        Falls back to :meth:`getArray` when tile mode is off.

        Returns:
            ``(image, transform)`` with the window's own transform.
        """
        if self.open_grid is False:
            return self.getArray()
        grid_idx = np.array([row, col])
        ul = grid_idx * (self.grid_size - self.overlap)
        lr = ul + self.grid_size
        # rasterio Window takes (col_off, row_off, width, height).
        window = Window(ul[1], ul[0], (lr[1] - ul[1]), (lr[0] - ul[0]))
        rgb = []
        for b in self.show_band:
            rgb.append(self.src_data.read(b, window=window))
        win_tf = self.src_data.window_transform(window)
        # NOTE(review): bands are cast to uint16 and only "uint32" data is
        # normalised here, unlike getArray (any non-uint8) - confirm intent.
        ima = cv2.merge([np.uint16(c) for c in rgb])
        if self.geoinfo["dtype"] == "uint32":
            ima = sample_norm(ima)
        return two_percentLinear(ima), win_tf

    def saveMask(self,
                 img: np.ndarray,
                 save_path: str,
                 geoinfo: Union[Dict, None]=None,
                 count: int=1) -> None:
        """Write a mask as GTiff using the given georeferencing.

        Args:
            img: Mask array; ``H x W`` when ``count == 1``, otherwise
                ``H x W x count``. NaNs are replaced and the data is cast
                to ``int16`` before writing.
            save_path: Output file path.
            geoinfo: Size/dtype/CRS/transform to use; defaults to the
                source raster's.
            count: Number of bands to write.
        """
        if geoinfo is None:
            geoinfo = self.geoinfo
        new_meta = self.src_data.meta.copy()
        new_meta.update({
            "driver": "GTiff",
            "width": geoinfo.xsize,
            "height": geoinfo.ysize,
            "count": count,
            "dtype": geoinfo.dtype,
            "crs": geoinfo.crs,
            "transform": geoinfo.geotf[:6],
            "nodata": 0
        })
        img = np.nan_to_num(img).astype("int16")
        with rasterio.open(save_path, "w", **new_meta) as tf:
            if count == 1:
                tf.write(img, indexes=1)
            else:
                for i in range(count):
                    tf.write(img[:, :, i], indexes=(i + 1))

    def saveMaskbyGrids(self,
                        img_list: List[List[np.ndarray]],
                        save_path: Union[str, None]=None,
                        geoinfo: Union[Dict, None]=None) -> np.ndarray:
        """Stitch per-tile masks into one mask and optionally save it.

        Every tile in ``img_list`` is pasted at
        ``(row, col) * (grid_size - overlap)``. Tiles go on two canvases
        in a checkerboard pattern; within a canvas only non-zero pixels
        are written, and where both canvases hold a label the odd-parity
        canvas wins. The result is cropped to the raster size.

        Args:
            img_list: Tile masks indexed ``[row][col]``.
            save_path: When given, the stitched mask is written there
                with :meth:`saveMask`.
            geoinfo: Georeferencing to use; defaults to the source's.

        Returns:
            ``uint8`` mask of shape ``(ysize, xsize)``.

        Raises:
            IndexError: If ``img_list`` holds no tiles.
        """
        if geoinfo is None:
            geoinfo = self.geoinfo
        if not img_list or not img_list[0]:
            raise IndexError("img_list holds no tile masks")
        raw_h, raw_w = geoinfo.ysize, geoinfo.xsize
        tile_h, tile_w = (int(v) for v in self.grid_size)
        step_h, step_w = (int(v) for v in self.grid_size - self.overlap)
        rows = len(img_list)
        cols = len(img_list[0])
        canvas_shape = (max(raw_h, (rows - 1) * step_h + tile_h),
                        max(raw_w, (cols - 1) * step_w + tile_w))
        even = np.zeros(canvas_shape, dtype=np.uint8)
        odd = np.zeros(canvas_shape, dtype=np.uint8)
        for i in range(rows):
            for j in range(cols):
                mask = img_list[i][j]
                mask_h, mask_w = mask.shape[:2]
                # Border masks may be smaller than a tile; pad with zeros.
                tile = np.zeros((tile_h, tile_w), dtype=np.uint8)
                tile[:mask_h, :mask_w] = mask
                top, left = i * step_h, j * step_w
                target = even if (i + j) % 2 == 0 else odd
                region = target[top:top + tile_h, left:left + tile_w]
                # Keep labels already present where this tile is empty.
                np.copyto(region, tile, where=tile != 0)
        result = np.where(odd != 0, odd, even)
        result = result[:raw_h, :raw_w]
        if save_path is not None:
            self.saveMask(result, save_path, geoinfo)
        return result
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
def check_gdal() -> bool:
    """Return ``True`` when GDAL's Python bindings can be imported.

    The legacy top-level ``gdal`` module is tried first, then
    ``osgeo.gdal``. ``KeyboardInterrupt`` and ``SystemExit`` raised
    during the first attempt are not swallowed.
    """
    try:
        import gdal  # noqa: F401
    except Exception:
        try:
            from osgeo import gdal  # noqa: F401
        except ImportError:
            return False
    return True
# Import the GDAL bindings only when they are importable, so this module
# still loads without them; IMPORT_STATE records whether that succeeded.
IMPORT_STATE = False
if check_gdal():
    try:
        # Legacy top-level module layout.
        import gdal
        import osr
        import ogr
    except Exception:
        # Packaged layout; Exception (not a bare except) keeps
        # KeyboardInterrupt/SystemExit propagating.
        from osgeo import gdal, osr, ogr
    IMPORT_STATE = True
# Save a shapefile
def save_shp(shp_path: str, tif_path: str, ignore_index: int=0) -> str:
    """Polygonize band 1 of a raster into an ESRI Shapefile.

    An existing file at ``shp_path`` is removed first. The polygons are
    written to a layer named "segmentation" with an integer field "DN",
    and every feature whose DN matches ``ignore_index`` is deleted
    afterwards.

    Args:
        shp_path: Output shapefile path.
        tif_path: Input raster path.
        ignore_index: DN value whose polygons are dropped.

    Returns:
        A fixed success message.

    Raises:
        ImportError: If the GDAL bindings could not be imported.
    """
    if IMPORT_STATE != True:
        raise ImportError("can't import gdal, osr, ogr!")

    raster_ds = gdal.Open(tif_path)
    band = raster_ds.GetRasterBand(1)
    band_mask = band.GetMaskBand()

    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
    gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
    ogr.RegisterAll()
    driver = ogr.GetDriverByName("ESRI Shapefile")
    if osp.exists(shp_path):
        os.remove(shp_path)
    shp_ds = driver.CreateDataSource(shp_path)

    # The layer reuses the raster's projection.
    spatial_ref = osr.SpatialReference(wkt=raster_ds.GetProjection())
    layer = shp_ds.CreateLayer(
        "segmentation", geom_type=ogr.wkbPolygon, srs=spatial_ref)
    field_def = ogr.FieldDefn("DN", ogr.OFTInteger)
    layer.CreateField(field_def)
    gdal.Polygonize(band, band_mask, layer, 0, [])

    # Drop the polygons that carry the ignored value.
    filtered = shp_ds.GetLayer()
    filtered.SetAttributeFilter("DN = '{}'".format(str(ignore_index)))
    for feature in filtered:
        filtered.DeleteFeature(feature.GetFID())

    shp_ds.Destroy()
    raster_ds = None
    return "Dataset creation successfully!"
# The video propagation and fusion code was heavily based on https://github.com/hkchengrex/MiVOS # Users should be careful about adopting these functions in any commercial matters. # https://github.com/hkchengrex/MiVOS/blob/main/LICENSE from .inference_core import InferenceCore from .video_tools import overlay_davis
\ No newline at end of file
This diff is collapsed.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.inference as paddle_infer
def load_model(model_path, param_path, use_gpu=None):
    """Build a Paddle Inference predictor for the given model files.

    Args:
        model_path: Path of the model file.
        param_path: Path of the parameter file.
        use_gpu: ``True``/``False`` to force the device; ``None`` picks
            GPU when Paddle reports it was compiled with CUDA.

    Returns:
        The predictor created from the assembled config.
    """
    config = paddle_infer.Config(model_path, param_path)
    if use_gpu is None:
        # TODO: this may return False even though a GPU is usable.
        use_gpu = bool(paddle.device.is_compiled_with_cuda())

    if use_gpu:
        config.enable_use_gpu(500, 0)
        for pass_name in ("conv_elementwise_add_act_fuse_pass",
                          "conv_elementwise_add2_act_fuse_pass",
                          "conv_elementwise_add_fuse_pass"):
            config.delete_pass(pass_name)
        config.switch_ir_optim()
        config.enable_memory_optim()
    else:
        config.enable_mkldnn()
        # TODO: fluid is being deprecated; work out how to detect this.
        config.switch_ir_optim(True)
        config.set_cpu_math_library_num_threads(10)

    return paddle_infer.create_predictor(config)
def jit_load(path):
    """Load a model with ``paddle.jit.load`` and switch it to eval mode."""
    loaded = paddle.jit.load(path)
    loaded.eval()
    return loaded
def calculate_memorize(model, frame, masks):
    """Run the predictor on a frame and its masks.

    ``frame`` and ``masks`` are copied to the predictor's first and
    second inputs; after the run the first and second outputs are copied
    back.

    Returns:
        tuple: ``(first_output, second_output)``.
    """
    in_names = model.get_input_names()
    in_handles = [model.get_input_handle(in_names[idx]) for idx in range(2)]
    for handle, array in zip(in_handles, (frame, masks)):
        handle.copy_from_cpu(array)

    model.run()

    out_names = model.get_output_names()
    first = model.get_output_handle(out_names[0]).copy_to_cpu()
    second = model.get_output_handle(out_names[1]).copy_to_cpu()
    return first, second
def calculate_segmentation(model, frame, keys, values):
    """Run the predictor on a frame with memory keys and values.

    ``frame``, ``keys`` and ``values`` are copied to the predictor's
    first three inputs; after the run the first and second outputs are
    copied back.

    Returns:
        tuple: ``(first_output, second_output)``.
    """
    in_names = model.get_input_names()
    in_handles = [model.get_input_handle(in_names[idx]) for idx in range(3)]
    for handle, array in zip(in_handles, (frame, keys, values)):
        handle.copy_from_cpu(array)

    model.run()

    out_names = model.get_output_names()
    first = model.get_output_handle(out_names[0]).copy_to_cpu()
    second = model.get_output_handle(out_names[1]).copy_to_cpu()
    return first, second
def calculate_attention(model, mk16, qk16, pos_mask, neg_mask):
    """Run the predictor on two key tensors and two click masks.

    The four arguments are copied, in order, to the predictor's first
    four inputs; after the run the first output is copied back and
    returned.
    """
    in_names = model.get_input_names()
    in_handles = [model.get_input_handle(in_names[idx]) for idx in range(4)]
    for handle, array in zip(in_handles, (mk16, qk16, pos_mask, neg_mask)):
        handle.copy_from_cpu(array)

    model.run()

    out_names = model.get_output_names()
    return model.get_output_handle(out_names[0]).copy_to_cpu()
def calculate_fusion(model, im, seg1, seg2, attn, time):
    """Run the predictor on an image, two segmentations, attention, time.

    The five arguments are copied, in order, to the predictor's first
    five inputs; after the run the first output is copied back and
    returned.
    """
    in_names = model.get_input_names()
    in_handles = [model.get_input_handle(in_names[idx]) for idx in range(5)]
    for handle, array in zip(in_handles, (im, seg1, seg2, attn, time)):
        handle.copy_from_cpu(array)

    model.run()

    out_names = model.get_output_names()
    return model.get_output_handle(out_names[0]).copy_to_cpu()
# The video propagation and fusion code was heavily based on https://github.com/hkchengrex/MiVOS
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/hkchengrex/MiVOS/blob/main/LICENSE
from paddle.vision import transforms
# Per-channel colour triple; each entry equals the matching `mean` value
# below multiplied by 255 and rounded.
im_mean = (124, 116, 104)
# Per-channel normalisation, (x - mean) / std.
# NOTE(review): presumably the ImageNet channel statistics - confirm.
im_normalization = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225], )
# Inverse of `im_normalization`: with mean = -m / s and std = 1 / s,
# (x - mean) / std evaluates to x * s + m.
inv_im_trans = transforms.Normalize(
mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
std=[1 / 0.229, 1 / 0.224, 1 / 0.225], )
# The video propagation and fusion code was heavily based on https://github.com/hkchengrex/MiVOS
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/hkchengrex/MiVOS/blob/main/LICENSE
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.vision import transforms
# Per-channel normalisation, (x - mean) / std.
# NOTE(review): presumably the ImageNet channel statistics - confirm.
im_normalization = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# Inverse of `im_normalization`: with mean = -m / s and std = 1 / s,
# (x - mean) / std evaluates to x * s + m.
inv_im_trans = transforms.Normalize(
mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
std=[1 / 0.229, 1 / 0.224, 1 / 0.225])
def images_to_paddle(frames):
    """Convert a stack of frames to a normalised float32 Paddle tensor.

    The input's axes are reordered with ``[0, 3, 1, 2]`` (channel axis
    moved in front of the two spatial axes), a leading axis of size one
    is added, values are divided by 255, and every frame is passed
    through ``im_normalization``.

    Returns:
        The 5-D tensor ``(1, T, C, H, W)``.
    """
    tensor = paddle.to_tensor(frames.transpose([0, 3, 1, 2]))
    tensor = tensor.astype('float32').unsqueeze(0) / 255
    frame_total = tensor.shape[1]
    for index in range(frame_total):
        tensor[0, index] = im_normalization(tensor[0, index])
    return tensor
def compute_tensor_iu(seg, gt):
    """Return ``(intersection, union)`` sizes of two boolean masks.

    The masks are combined with ``&`` and ``|``, cast to float32 and
    summed.
    """
    overlap = seg & gt
    combined = seg | gt
    return overlap.astype('float32').sum(), combined.astype('float32').sum()
def compute_np_iu(seg, gt):
    """Return ``(intersection, union)`` sizes of two NumPy boolean masks.

    The masks are combined with ``&`` and ``|``, cast to ``np.float32``
    and summed.
    """
    overlap = seg & gt
    combined = seg | gt
    return overlap.astype(np.float32).sum(), combined.astype(np.float32).sum()
def compute_tensor_iou(seg, gt):
    """Return the IoU of two masks via ``compute_tensor_iu``.

    ``1e-6`` is added to numerator and denominator, so two empty masks
    yield 1 instead of dividing by zero.
    """
    eps = 1e-6
    inter, uni = compute_tensor_iu(seg, gt)
    return (inter + eps) / (uni + eps)
def compute_np_iou(seg, gt):
    """Return the IoU of two NumPy masks via ``compute_np_iu``.

    ``1e-6`` is added to numerator and denominator, so two empty masks
    yield 1 instead of dividing by zero.
    """
    eps = 1e-6
    inter, uni = compute_np_iu(seg, gt)
    return (inter + eps) / (uni + eps)
def compute_multi_class_iou(seg, gt):
    """Return the mean per-object IoU of class scores against masks.

    Args:
        seg: Score tensor laid out as ``k * h * w``; the arg-max over
            axis 0 gives the predicted label, and it includes a
            background class.
        gt: Ground truth laid out as ``k * 1 * h * w``; thresholded at
            0.5.
    """
    class_total = gt.shape[0]
    predicted = paddle.argmax(seg, axis=0)
    iou_total = 0
    # Label 0 is background, so object `label` pairs with gt[label - 1].
    for label in range(1, class_total + 1):
        iou_total += compute_tensor_iou(predicted == label,
                                        gt[label - 1, 0] > 0.5)
    return (iou_total + 1e-6) / (class_total + 1e-6)
def compute_multi_class_iou_idx(seg, gt):
    """Return the mean per-object IoU of a label map against masks.

    Args:
        seg: Label map laid out as ``h * w``; label 0 is background.
        gt: Ground truth laid out as ``k * h * w``; thresholded at 0.5.
    """
    class_total = gt.shape[0]
    # Label k + 1 in `seg` pairs with gt[k].
    per_class = (compute_np_iou(seg == (k + 1), gt[k] > 0.5)
                 for k in range(class_total))
    return (sum(per_class) + 1e-6) / (class_total + 1e-6)
def compute_multi_class_iou_both_idx(seg, gt):
    """Return the mean per-label IoU of two ``h * w`` label maps.

    Labels run from 1 to ``gt.max()``; every label in that range
    contributes one IoU term.
    """
    highest_label = gt.max()
    per_label = (compute_np_iou(seg == label, gt == label)
                 for label in range(1, highest_label + 1))
    return (sum(per_label) + 1e-6) / (highest_label + 1e-6)
# STM
def pad_divide_by(in_img, d, in_size=None):
    """Pad a tensor so its last two dimensions become multiples of ``d``.

    The padding is split between both sides of each axis, the extra
    pixel (if any) going to the trailing side.

    Args:
        in_img: 4-D tensor, or 5-D tensor which is flattened to 4-D for
            padding and returned with a leading axis of size one.
        d: Divisor the padded height and width must be multiples of.
        in_size: Optional ``(h, w)`` used instead of the tensor's own
            last two dimensions when computing the padding.

    Returns:
        tuple: ``(padded, pad_array)`` where ``pad_array`` is
        ``(lw, uw, lh, uh)`` - padding before/after the width axis, then
        before/after the height axis.
    """
    h, w = in_img.shape[-2:] if in_size is None else in_size

    def _round_up(length):
        # Smallest multiple of d that is not below `length`.
        remainder = length % d
        return length + d - remainder if remainder > 0 else length

    extra_h = _round_up(h) - h
    extra_w = _round_up(w) - w
    lh = int(extra_h / 2)
    uh = int(extra_h) - lh
    lw = int(extra_w / 2)
    uw = int(extra_w) - lw
    pad_array = (lw, uw, lh, uh)

    if len(in_img.shape) == 5:
        _, _, C, H, W = in_img.shape
        flattened = in_img.reshape([-1, C, H, W])
        padded = F.pad(flattened, pad_array, data_format='NCHW').unsqueeze(0)
        return padded, pad_array
    return F.pad(in_img, pad_array, data_format='NCHW'), pad_array
def unpad(img, pad):
    """Remove padding from the last two axes of a 4-D array/tensor.

    Args:
        img: Array indexed as ``[n, c, h, w]``.
        pad: ``(left, right, top, bottom)`` amounts to strip, i.e.
            ``pad[0]``/``pad[1]`` for the last axis and
            ``pad[2]``/``pad[3]`` for the one before it.

    Returns:
        The cropped view; ``img`` itself when all pads are zero.
    """
    left, right, top, bottom = pad[0], pad[1], pad[2], pad[3]
    # Explicit end indices: a negative stop of "-0" would slice to empty
    # when only the leading side is padded.
    if top + bottom > 0:
        img = img[:, :, top:img.shape[2] - bottom, :]
    if left + right > 0:
        img = img[:, :, :, left:img.shape[3] - right]
    return img
def unpad_3dim(img, pad):
    """Remove padding from the last two axes of a 3-D array/tensor.

    Args:
        img: Array indexed as ``[c, h, w]``.
        pad: ``(left, right, top, bottom)`` amounts to strip, i.e.
            ``pad[0]``/``pad[1]`` for the last axis and
            ``pad[2]``/``pad[3]`` for the one before it.

    Returns:
        The cropped view; ``img`` itself when all pads are zero.
    """
    left, right, top, bottom = pad[0], pad[1], pad[2], pad[3]
    # Explicit end indices: a negative stop of "-0" would slice to empty
    # when only the leading side is padded.
    if top + bottom > 0:
        img = img[:, top:img.shape[1] - bottom, :]
    if left + right > 0:
        img = img[:, :, left:img.shape[2] - right]
    return img
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment