"docs/vscode:/vscode.git/clone" did not exist on "19b21010e0e85a58404a1978e43959b671215c29"
optical_flow.py 1.86 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import cv2
import numpy as np
import torch

def rescale_tensor(tensor):
    """Linearly rescale ``tensor`` so its global minimum maps to 0 and maximum to 255.

    The minimum and maximum are taken over *every* element of the tensor (no
    per-axis normalization). The result is a floating-point tensor of the same
    shape, clamped to ``[0, 255]`` to absorb floating-point rounding.

    A constant tensor (``max == min``) is mapped to all zeros rather than to
    the NaNs that a ``0 / 0`` division would produce.

    Args:
        tensor: non-empty numeric tensor (``torch.min`` raises on empty input).

    Returns:
        Floating-point tensor with values in ``[0, 255]``.
    """
    min_val, max_val = torch.min(tensor), torch.max(tensor)
    value_range = max_val - min_val
    # Guard the constant-input case: a zero range would give 0/0 = NaN, which
    # clamp() keeps and which casts to garbage when converted to uint8 later.
    # (tensor - min_val) is all zeros there, so dividing by 1 yields zeros.
    safe_range = torch.where(
        value_range == 0, torch.ones_like(value_range), value_range
    )
    tensor = (tensor - min_val) / safe_range * 255.0
    return tensor.clamp(0, 255)

def compute_optical_flow(video_tensor):
    """Compute dense Farneback optical flow between consecutive video frames.

    For every video in the batch and every pair of adjacent frames ``(t, t+1)``,
    this estimates two flows with OpenCV's ``calcOpticalFlowFarneback`` on
    grayscale versions of the frames:

    * the forward flow, from frame ``t`` to frame ``t+1``;
    * the backward flow, from frame ``t+1`` to frame ``t``.

    Args:
        video_tensor: tensor of shape ``[B, C, T, H, W]`` with ``C == 3``.
            Channels are treated as RGB when converting to grayscale. Values
            may be in any range. The whole tensor (all videos jointly, not
            per video) is min-max rescaled to ``[0, 255]`` via
            ``rescale_tensor`` and then truncated to ``uint8``.

    Returns:
        ``(forward_flow, backward_flow)``, each a CPU ``float32`` tensor of
        shape ``[B, 2, T-1, H, W]``. Channel 0 is the x displacement and
        channel 1 the y displacement, in pixels, as produced by OpenCV. When
        there are no frame pairs (``T < 2``) or the input has no elements,
        both tensors are returned with that shape and zero elements.

    Raises:
        AssertionError: if the input does not have exactly 3 channels.
    """
    B, C, T, H, W = video_tensor.shape
    assert C == 3, "Input video tensor must have 3 channels (RGB)."

    num_pairs = max(T - 1, 0)
    if num_pairs == 0 or video_tensor.numel() == 0:
        # Nothing to compare (and torch.min in rescale_tensor fails on empty
        # tensors). Return correctly shaped empty results instead of failing
        # in np.stack on an empty list.
        empty = torch.zeros((B, 2, num_pairs, H, W), dtype=torch.float32)
        return empty, empty.clone()

    # Rescale once, then move the whole video to host memory in one transfer,
    # laid out as C-contiguous [B, T, H, W, C] uint8 frames for OpenCV.
    # detach() lets this work on tensors that require grad.
    video_tensor = rescale_tensor(video_tensor).float()
    frames = (
        video_tensor.detach()
        .permute(0, 2, 3, 4, 1)
        .cpu()
        .numpy()
        .astype(np.uint8, order="C")
    )

    # Farneback parameters, in OpenCV's positional order:
    # pyr_scale, levels, winsize, iterations, poly_n, poly_sigma, flags.
    farneback_params = (0.5, 3, 15, 3, 5, 1.2, 0)

    forward_flow = np.empty((B, num_pairs, H, W, 2), dtype=np.float32)
    backward_flow = np.empty_like(forward_flow)

    for b in range(B):
        # Convert each frame to grayscale once; interior frames take part in
        # two pairs.
        grays = [cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) for frame in frames[b]]
        for t, (prev_gray, next_gray) in enumerate(zip(grays, grays[1:])):
            forward_flow[b, t] = cv2.calcOpticalFlowFarneback(
                prev_gray, next_gray, None, *farneback_params
            )
            backward_flow[b, t] = cv2.calcOpticalFlowFarneback(
                next_gray, prev_gray, None, *farneback_params
            )

    # [B, T-1, H, W, 2] -> [B, 2, T-1, H, W]
    return (
        torch.from_numpy(forward_flow).permute(0, 4, 1, 2, 3),
        torch.from_numpy(backward_flow).permute(0, 4, 1, 2, 3),
    )