# multimodal_dataset.py
import json
import logging
import os
import traceback
from typing import Dict, Sequence, Union, List

import numpy as np
import torch
import moviepy.editor as mp
from PIL import Image
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

from ovis.model.modeling_ovis import Ovis
from ovis.train.arguments import TrainingArguments
from ovis.util.constants import IGNORE_ID, BEGIN_LINE, END_LINE, VISUAL_ATOM_ID, INDICATOR_IDS


class MultimodalDataset(Dataset):
    """Base class for multimodal training datasets; subclasses implement `load` and `__getitem__`."""

    def __init__(self, name: str, info: Dict, model: Ovis, training_args: TrainingArguments):
        self.name = name
        self.model = model
        self.training_args = training_args

        self.meta_file = info['meta_file']
        self.image_dir = info['image_dir']
        self.caption_template = info.get('caption_template', None)
        self.text_tokenizer = self.model.text_tokenizer
        self.visual_tokenizer = self.model.visual_tokenizer
        self.text_max_length = training_args.text_max_length
        self.min_frames = training_args.min_frames
        self.max_frames = training_args.max_frames

        self.samples = self.load()

    def load(self):
        raise NotImplementedError

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def __len__(self):
        return len(self.samples)

    def read_image(self, path):
        try:
            full_path = os.path.join(self.image_dir, path)
            image = Image.open(full_path).convert('RGB')
            return image, None
        except Exception as e:
            return None, e

    def read_video(self, sample, min_frames, max_frames):
        def _sampling_idx(_len, _min, _max):
            if _len < _min or _len > _max:
                tgt_len = _min if _len < _min else _max
                stride = _len / tgt_len
                sampled_ids = []
                for i in range(tgt_len):
                    start = int(np.round(stride * i))
                    end = int(np.round(stride * (i + 1)))
                    sampled_ids.append(min(_len - 1, (start + end) // 2))
                return sampled_ids
            else:
                return list(range(_len))
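        # Worked example (illustrative, not from the original file): with _len=12,
        # _min=4, _max=6, the clip is downsampled to 6 frames by taking the midpoint
        # of each stride of length 2, i.e. indices [1, 3, 5, 7, 9, 11].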

        if "video_frames" in sample:
            frames = []
            frames_paths = sample['video_frames']
            sampled_ids = _sampling_idx(len(frames_paths), min_frames, max_frames)
            for idx in sampled_ids:
                # read_image already prefixes image_dir, so pass the relative frame path directly
                frame, last_e = self.read_image(frames_paths[idx])
                if frame is None:
                    return None, last_e
                frames.append(frame)
            return frames, None
        elif "video" in sample:
            video_path = os.path.join(self.image_dir, sample['video'])

            max_tries = 2
            last_e = None
            for _ in range(max_tries):
                try:
                    with mp.VideoFileClip(video_path) as clip:
                        total_frames = int(clip.fps * clip.duration)
                        sampled_ids = _sampling_idx(total_frames, min_frames, max_frames)
                        frames = [clip.get_frame(idx / clip.fps) for idx in sampled_ids]
                        frames = [Image.fromarray(frame, mode='RGB') for frame in frames]

                    if len(frames) == 0 or any(frame.size[0] < 5 or frame.size[1] < 5 for frame in frames):
                        raise ValueError("no frames decoded or a decoded frame is too small")
                    return frames, None
                except Exception as e:
                    last_e = f"read video error: {e}\n detailed info: {traceback.format_exc()}"
            return None, last_e
        else:
            return None, RuntimeError(f"missing `video_frames` and `video` in sample: {json.dumps(sample)}")
        
    def truncate_inputs(
        self, input_ids, pixel_values, grid_thws, labels, max_length
    ):
        """Truncate `input_ids`/`labels` to `max_length`, drop any visual segment cut
        mid-way, and trim `pixel_values`/`grid_thws` to the remaining visual atoms."""
        input_ids = input_ids[0, :max_length]
        labels = labels[0, :max_length]
        if input_ids[-1] in (VISUAL_ATOM_ID, INDICATOR_IDS[0], INDICATOR_IDS[2]):  # incomplete visual input
            last_text_id_pos = (input_ids >= 0).nonzero()[-1].item() + 1
            input_ids = input_ids[:last_text_id_pos]
            labels = labels[:last_text_id_pos]
        num_visual_atom = torch.eq(input_ids, VISUAL_ATOM_ID).sum().item()
        if num_visual_atom > 0:
            vit = self.model.visual_tokenizer.vit
            # number of ViT patches merged into one visual atom token
            ratio = vit.config.temporal_patch_size * vit.config.hidden_stride ** 2
            cumsum_patches = grid_thws.prod(dim=1).cumsum(dim=0)
            last_grid_index = (cumsum_patches // ratio == num_visual_atom).nonzero().item()
            pixel_values = pixel_values[:cumsum_patches[last_grid_index]]
            grid_thws = grid_thws[:last_grid_index+1]
        else:
            pixel_values, grid_thws = None, None

        return input_ids, pixel_values, grid_thws, labels
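

# Illustrative sketch (not part of the original file): a minimal text-only subclass
# showing the contract DataCollatorForMultimodalDataset expects from __getitem__,
# namely the keys "input_ids", "pixel_values", "grid_thws", "attention_mask", and
# "labels". The meta-file format and the "caption" key are assumptions made for
# this example only.
class ExampleCaptionDataset(MultimodalDataset):
    def load(self):
        # assumed format: meta_file is a JSON list of {"caption": ...} records
        with open(self.meta_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        caption = self.samples[i]['caption']
        input_ids = self.text_tokenizer(
            caption,
            return_tensors='pt',
            truncation=True,
            max_length=self.text_max_length
        ).input_ids[0]
        return dict(
            input_ids=input_ids,
            pixel_values=None,  # no visual input in this text-only sketch
            grid_thws=None,
            attention_mask=torch.ones_like(input_ids, dtype=torch.bool),
            labels=input_ids.clone()
        )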


class DataCollatorForMultimodalDataset:
    """Collates per-sample dicts into a batch: pads text tensors and concatenates visual tensors."""

    def __init__(self, text_tokenizer: PreTrainedTokenizer):
        self.text_tokenizer = text_tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        keys = ("input_ids", "pixel_values", "grid_thws", "attention_mask", "labels")
        input_ids, pixel_values, grid_thws, attention_mask, labels = (
            tuple(instance[key] for instance in instances) for key in keys
        )
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.text_tokenizer.pad_token_id
        )
        pixel_values = [x for x in pixel_values if x is not None]
        pixel_values = torch.cat(pixel_values, dim=0) if len(pixel_values) > 0 else None
        grid_thws = [x for x in grid_thws if x is not None]
        grid_thws = torch.cat(grid_thws, dim=0) if len(grid_thws) > 0 else None
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            attention_mask,
            batch_first=True,
            padding_value=False
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_ID
        )
        # if no sequence in the batch is padded, append one pad column so padding is always present
        if 0 not in attention_mask:
            input_ids = F.pad(input_ids, (0, 1), value=self.text_tokenizer.pad_token_id)
            attention_mask = F.pad(attention_mask, (0, 1), value=False)
            labels = F.pad(labels, (0, 1), value=IGNORE_ID)
            
        if torch.all(labels == IGNORE_ID):
            logging.warning('[DataCollatorForMultimodalDataset] All samples in the current batch are ignored.')
        return dict(
            input_ids=input_ids,
            pixel_values=pixel_values,
            grid_thws=grid_thws,
            attention_mask=attention_mask,
            labels=labels
        )
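

# Usage sketch (illustrative, not part of the original file): collating two text-only
# instances of different lengths. SimpleNamespace stands in for a real tokenizer here,
# since the collator only reads `pad_token_id`.
if __name__ == "__main__":
    from types import SimpleNamespace

    collator = DataCollatorForMultimodalDataset(SimpleNamespace(pad_token_id=0))
    instances = [
        dict(
            input_ids=torch.tensor([101, 7, 8, 9]),
            pixel_values=None,
            grid_thws=None,
            attention_mask=torch.ones(4, dtype=torch.bool),
            labels=torch.tensor([IGNORE_ID, 7, 8, 9]),
        ),
        dict(
            input_ids=torch.tensor([101, 5, 6]),
            pixel_values=None,
            grid_thws=None,
            attention_mask=torch.ones(3, dtype=torch.bool),
            labels=torch.tensor([IGNORE_ID, 5, 6]),
        ),
    ]
    batch = collator(instances)
    # text tensors are padded to the longest sequence; visual tensors stay None here
    print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['labels'].shape)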