nextqa.py 5.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import sys
from typing import List

import av
from datasets import load_dataset


def find_video_files(video_dir) -> List[str]:
    """Collect the video files located at or below ``video_dir``.

    Args:
        video_dir: Path to a single file or to a directory tree to scan.

    Returns:
        ``[video_dir]`` when it is an existing file (no extension check is
        applied in that case). Otherwise the joined paths of every file under
        ``video_dir`` whose name ends in ``.mp4``, ``.avi`` or ``.mov``, in
        ``os.walk`` order. A path that does not exist yields an empty list.
    """
    if os.path.isfile(video_dir):
        return [video_dir]

    video_extensions = (".mp4", ".avi", ".mov")
    video_files = []
    # os.walk already descends into every sub-directory, so no manual
    # recursion is needed. Entries of ``names`` are bare file names; testing
    # them with os.path.isdir would resolve them against the current working
    # directory rather than ``root`` and could pull in unrelated videos.
    for root, _dirs, names in os.walk(video_dir):
        for name in names:
            if name.endswith(video_extensions):
                video_files.append(os.path.join(root, name))
    return video_files


def video_frames(video_path, max_frames) -> int:
    """Return the frame count of a video, capped at ``max_frames``.

    Args:
        video_path: Path of the video file, passed to ``av.open``.
        max_frames: Upper bound applied to the returned count.

    Returns:
        ``min(frames, max_frames)`` where ``frames`` is the ``frames``
        attribute of the first video stream in the container.

    Raises:
        IndexError: If the container has no video stream.
    """
    # Close the container as soon as the metadata has been read; this helper
    # is called once per video, so an unclosed container leaks a handle each
    # time.
    with av.open(video_path) as container:
        # NOTE(review): the stream's ``frames`` value comes from container
        # metadata and may be 0 when the count is unknown -- confirm whether
        # callers need a fallback for that case.
        total_frames = container.streams.video[0].frames
    return min(total_frames, max_frames)


class Video:
    """A video file path paired with a frame count.

    Instances are iterable so they can be unpacked as
    ``path, num_frames = video``.
    """

    def __init__(self, video_path, num_frames):
        self.path, self.num_frames = video_path, num_frames

    def __str__(self):
        return f"Video({self.path}, {self.num_frames})"

    def __iter__(self):
        # Snapshot the fields in unpacking order: (path, num_frames).
        fields = (self.path, self.num_frames)
        return iter(fields)


class VideoPrompt(Video):
    """A ``Video`` that additionally carries a text prompt.

    Iterating yields ``path``, ``num_frames`` and ``prompt`` in that order,
    so an instance unpacks as ``path, num_frames, prompt = item``.
    """

    def __init__(self, video_path, num_frames, prompt):
        super().__init__(video_path, num_frames)
        self.prompt = prompt

    def __str__(self):
        return f"VideoPrompt({self.path}, {self.num_frames}, {self.prompt})"

    def __iter__(self):
        # Snapshot the fields in unpacking order: (path, num_frames, prompt).
        fields = (self.path, self.num_frames, self.prompt)
        return iter(fields)


class VideoLoader:
    """Common base class of the video loaders in this module.

    It defines no behaviour of its own; subclasses provide ``__iter__``.
    """


class VideoFileLoader(VideoLoader):
    """
    Load all the videos in a directory

    Iterating yields single ``Video`` objects when ``batch_size`` is 1, and
    lists of ``Video`` objects otherwise. The last list may hold fewer than
    ``batch_size`` items when the number of files is not a multiple of it.
    """

    def __init__(self, video_dir, batch_size=1, max_frames=sys.maxsize):
        """
        video_dir: file or directory handed to find_video_files
        batch_size: number of videos per yielded batch (1 yields bare items)
        max_frames: cap applied to every video's frame count
        """
        super().__init__()
        self.video_dir = video_dir
        self.video_files = find_video_files(video_dir)
        self.batch_size = batch_size
        self.max_frames = max_frames
        print(f"batch_size: {batch_size}, max_frames: {max_frames}")

    def __iter__(self):  # (file, number of frames)
        if self.batch_size == 1:
            for video_file in self.video_files:
                yield Video(video_file, video_frames(video_file, self.max_frames))
            return

        batch = []
        for video_file in self.video_files:
            batch.append(Video(video_file, video_frames(video_file, self.max_frames)))
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Flush the trailing partial batch so the last videos are not
        # silently skipped.
        if batch:
            yield batch


class NExTQALoader(VideoLoader):
    """
    Load videos and prompts from NExT dataset
    set: train, test or validation

    Iterating yields single ``VideoPrompt`` objects when ``batch_size`` is 1,
    and lists of ``VideoPrompt`` objects otherwise. The last list may hold
    fewer than ``batch_size`` items.
    """

    def __init__(
        self, video_dir, batch_size=1, max_frames=sys.maxsize, dset="test", task="OE"
    ):
        """
        video_dir: file or directory handed to find_video_files
        batch_size: number of prompts per yielded batch (1 yields bare items)
        max_frames: cap applied to every entry's frame count
        dset: name of the split selected from the loaded dataset
        task: 'MC' or 'OE'; passed to load_dataset as the configuration name.
              With 'MC' the answer choices are appended to each prompt.
        """
        super().__init__()
        self.task = task
        print(f"Loading the {dset} data of {task} from lmms-lab/NExTQA")
        self.ds = load_dataset("lmms-lab/NExTQA", task)[dset]

        self.video_dir = video_dir
        self.video_files = find_video_files(video_dir)
        # Map video id -> path, where the id is the file name up to its first
        # dot. os.path.basename keeps this correct for any path separator.
        self.video_to_path = dict()
        for video_file in self.video_files:
            video_id = os.path.basename(video_file).split(".")[0]
            self.video_to_path[video_id] = video_file

        self.batch_size = batch_size
        self.max_frames = max_frames

    def get_video_prompt(self, entry, max_frames) -> VideoPrompt:
        """
        Build the VideoPrompt for one dataset entry.

        entry: mapping with the keys "video", "frame_count" and "question"
               (plus "a0".."a3" when task is 'MC')
        max_frames: cap applied to entry["frame_count"]

        Raises KeyError when no local file matches the entry's video id and
        AssertionError when the matched path no longer exists.
        """
        # The lookup table is keyed by strings taken from file names, so
        # normalise the id. str() is a no-op when it already is a string.
        # NOTE(review): confirm the type of the dataset's "video" column.
        video_id = str(entry["video"])
        video_path = self.video_to_path[video_id]
        assert os.path.exists(video_path), f"Video not found: {video_path}"
        num_frames = min(entry["frame_count"], max_frames)
        prompt = entry["question"] + "?"
        if self.task == "MC":  # add choices
            # NOTE(review): only a0..a3 are listed; confirm whether the MC
            # entries carry further choices that should be included.
            prompt += f' a0: {entry["a0"]}, a1: {entry["a1"]}, a2: {entry["a2"]}, a3: {entry["a3"]}'
        return VideoPrompt(video_path, num_frames, prompt)

    def __iter__(self):
        if self.batch_size == 1:
            for entry in self.ds:
                yield self.get_video_prompt(entry, self.max_frames)
            return

        batch = []
        for entry in self.ds:
            batch.append(self.get_video_prompt(entry, self.max_frames))
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Flush the trailing partial batch so the last entries are not
        # silently skipped.
        if batch:
            yield batch


# main
if __name__ == "__main__":
    # Manual smoke test: iterate the loader in batches of 16 and print each
    # item. A VideoFileLoader(video_dir, batch_size=16) can be walked the same
    # way; its items unpack to (video_file, num_frames) only.
    video_dir = "./videos"
    video_loader = NExTQALoader(
        video_dir,
        batch_size=16,
        dset="test",
        task="OE",
    )
    for batch in video_loader:
        print(f"Number of videos in batch: {len(batch)}")
        for item in batch:
            # Each item unpacks through its __iter__.
            video_file, num_frames, prompt = item
            print(
                f"Video: {video_file} number of frames: {num_frames}, prompt: {prompt}"
            )