video_tools.py 2.81 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# The video propagation and fusion code was heavily based on https://github.com/hkchengrex/MiVOS
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/hkchengrex/MiVOS/blob/main/LICENSE

import glob
import os

import cv2
import numpy as np
import paddle
import paddle.nn.functional as F
from PIL import Image

from eiseg.util.vis import get_palette


def load_video(path, min_side=480):
    """Decode every frame of a video into a single RGB array.

    Args:
        path: file path (or any source accepted by ``cv2.VideoCapture``).
        min_side: when truthy, each frame is resized with bicubic
            interpolation so that its shorter side equals ``min_side``
            (aspect ratio kept, integer division). When falsy, frames keep
            their native size.

    Returns:
        ``(frames, fps)``. ``frames`` is the decoded RGB frames stacked along
        a new leading axis, and ``fps`` is the value OpenCV reports for
        ``CAP_PROP_FPS``.

    Raises:
        ValueError: if no frame could be decoded from ``path``.
    """
    frame_list = []
    cap = cv2.VideoCapture(path)
    try:
        # Query FPS while the capture is still open: after release() OpenCV
        # reports 0 for every property.
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok or frame is None:
                break
            # OpenCV decodes to BGR; callers work in RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if min_side:
                h, w = frame.shape[:2]
                short_side = min(w, h)
                new_w = w * min_side // short_side
                new_h = h * min_side // short_side
                frame = cv2.resize(
                    frame, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
            frame_list.append(frame)
    finally:
        # Always free the decoder / file handle, even if decoding fails.
        cap.release()

    if not frame_list:
        raise ValueError(f'no frames could be decoded from {path!r}')
    frames = np.stack(frame_list, axis=0)
    return frames, fps


def load_masks(path, min_side=None):
    """Load every ``*.png`` mask in a directory into one uint8 array.

    Files are read in sorted filename order. If the first mask's maximum
    pixel value is exactly 255, the masks are treated as binary: after
    loading, every value > 128 becomes 1 and everything else becomes 0.

    Args:
        path: directory containing the mask PNG files.
        min_side: when truthy, each mask is resized with nearest-neighbour
            sampling so that its shorter side equals ``min_side`` (aspect
            ratio kept, integer division). Nearest-neighbour means label
            values are never interpolated.

    Returns:
        ``np.ndarray`` of dtype uint8 with one mask per file stacked along a
        new leading axis.

    Raises:
        IndexError: if ``path`` contains no ``*.png`` files.
        ValueError: (from ``np.stack``) if the masks end up with different
            shapes.
    """
    fnames = sorted(glob.glob(os.path.join(path, '*.png')))
    if not fnames:
        raise IndexError(f'no *.png mask files found in {path!r}')

    # Decide binary-vs-label from the *unresized* first mask.
    with Image.open(fnames[0]) as first_image:
        binary_mask = (np.array(first_image).max() == 255)

    frame_list = []
    for fname in fnames:
        # `with` closes the underlying file; PIL otherwise keeps it open
        # until the Image object is garbage-collected.
        with Image.open(fname) as image:
            if min_side:
                w, h = image.size
                short_side = min(w, h)
                new_w = w * min_side // short_side
                new_h = h * min_side // short_side
                resized = image.resize((new_w, new_h), Image.NEAREST)
                frame_list.append(np.array(resized, dtype=np.uint8))
            else:
                frame_list.append(np.array(image, dtype=np.uint8))

    frames = np.stack(frame_list, axis=0)
    if binary_mask:
        frames = (frames > 128).astype(np.uint8)
    return frames


def overlay_davis(image, mask, alpha=0.5, palette=None):
    """Overlay a label mask on top of an RGB image (adapted from the DAVIS toolkit).

    Pixels with label > 0 are blended as ``(1 - alpha) * image +
    alpha * palette[label]``. Background pixels (label 0) are returned
    unchanged, whatever colour ``palette[0]`` holds.

    Args:
        image: image array. It is assumed to be (H, W, C) with C equal to the
            palette colour length (e.g. 3 for RGB). TODO confirm with callers.
        mask: 2-D integer label map with the same H, W as ``image``, or
            ``None`` to return an unmodified copy of ``image``.
        alpha: blend weight of the label colour inside the mask region.
        palette: sequence/array of per-label colours indexed by label value.
            When ``None`` or empty, ``get_palette(max_label + 1)`` is used.

    Returns:
        New uint8 array. The input ``image`` is not modified.
    """
    result = image.copy()
    if mask is None:
        return result

    # `not palette` raises for multi-element NumPy arrays, so test
    # explicitly for "not provided" / "empty".
    if palette is None or len(palette) == 0:
        palette = get_palette(np.max(mask) + 1)
    palette = np.array(palette)

    rgb_mask = palette[mask.astype(np.uint8)]
    mask_region = (mask > 0).astype(np.uint8)[:, :, np.newaxis]

    # The colour term is restricted to the mask region. Otherwise a
    # non-black palette[0] would tint the background and could push values
    # past 255 before the uint8 cast wraps them.
    blended = (result * (1 - mask_region) +
               (1 - alpha) * mask_region * result +
               alpha * mask_region * rgb_mask)
    return blended.astype(np.uint8)


def aggregate_wbg(prob, keep_bg=False, hard=False):
    """Merge per-object foreground probabilities into one distribution with a background channel.

    A background channel is built as the product over axis 0 of
    ``1 - prob``, i.e. the chance that no object claims the pixel if the
    objects are treated as independent. It is prepended to ``prob``. The
    stacked probabilities are clipped to ``[1e-7, 1 - 1e-7]``, converted to
    log-odds and renormalised with a softmax over axis 0.

    Args:
        prob: 4-D paddle tensor whose axis 0 indexes objects. The shape
            unpack below rejects any other rank.
        keep_bg: when True, return all channels with background at index 0.
            When False, drop the background channel.
        hard: when True, scale the logits by 1000 so the softmax becomes
            nearly one-hot.

    Returns:
        Paddle tensor of softmax probabilities over axis 0.
    """
    # Unpacking doubles as a rank-4 sanity check on `prob`.
    num_objects, _, height, width = prob.shape  # noqa: F841

    eps = 1e-7
    background = paddle.prod(1 - prob, axis=0, keepdim=True)
    stacked = paddle.concat([background, prob], 0).clip(eps, 1 - eps)
    logits = paddle.log((stacked / (1 - stacked)))

    if hard:
        # Sharpen towards a hard argmax assignment.
        logits = logits * 1000

    normalised = F.softmax(logits, axis=0)
    return normalised if keep_bg else normalised[1:]