conversation_dataset.py 3.74 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import copy
import json
import logging
from datetime import datetime
from typing import Dict

import torch

from ovis.train.dataset.multimodal_dataset import MultimodalDataset
from ovis.util.constants import VIDEO_TOKEN, IMAGE_TOKEN, IGNORE_ID
from ovis.util.utils import rank0_print


class ConversationDataset(MultimodalDataset):
    def load(self):
        rank0_print(f"[{datetime.now()}] Loading dataset {self.name} from {self.meta_file} begin")
        with open(self.meta_file, 'r', encoding='utf-8') as f:
            samples = json.load(f)
        rank0_print(f'#samples: {len(samples)}')
        rank0_print(f'sample: {samples[0]}')
        rank0_print(f"[{datetime.now()}] Loading dataset {self.name} end")
        return samples

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        sample = self.samples[i]
        conversations = sample["conversations"]

        # try:
        images = None
        videos = None
        n_image_or_frame = 0
        if 'image' in sample:
            images = []
            image_paths = sample['image']
            if isinstance(image_paths, str):
                image_paths = [image_paths]
            for image_path in image_paths:
                image, last_e = self.read_image(image_path)
                assert image is not None, f"Failed to read image from {image_path}"
                images.append(image)
            n_image_or_frame = len(images)
        elif 'video' in sample or 'video_frames' in sample:
            video, last_e = self.read_video(sample, min_frames=self.min_frames, max_frames=self.max_frames)
            video_path = sample.get('video') or sample.get('video_frames')
            assert video is not None, f"Failed to read video from {video_path}"
            videos = [video]
            n_image_or_frame = len(video)

        if images is None and videos is None:
            min_pixels = 0
            max_pixels = 0
        elif videos is not None:
            min_pixels = self.training_args.video_min_pixels
            max_pixels = self.training_args.video_max_pixels
        elif len(images) == 1:
            min_pixels = self.training_args.single_image_min_pixels
            max_pixels = self.training_args.single_image_max_pixels
        else:
            min_pixels = self.training_args.multiple_image_min_pixels
            max_pixels = self.training_args.multiple_image_max_pixels

        if min_pixels < 0:
            min_pixels = self.training_args.single_image_min_pixels
        if max_pixels < 0:
            max_pixels = max(min_pixels, self.training_args.single_image_max_pixels // n_image_or_frame)

        prompt, input_ids, pixel_values, grid_thws, labels = self.model.preprocess_inputs(
            conversations,
            images=images,
            videos=videos,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            generation_preface=None,
            return_labels=True,
        )
        if pixel_values is None:
            input_ids, pixel_values, grid_thws, labels = self.truncate_inputs(
                input_ids, pixel_values, grid_thws, labels, max_length=self.training_args.text_max_length
            )
        else:
            input_ids, pixel_values, grid_thws, labels = self.truncate_inputs(
                input_ids, pixel_values, grid_thws, labels, max_length=self.training_args.multimodal_max_length
            )
        assert self.text_tokenizer.pad_token_id not in input_ids, \
            "The sample's text contains a padding token: `{self.text_tokenizer.pad_token}`"

        del sample
        return dict(
            input_ids=input_ids,
            pixel_values=pixel_values,
            grid_thws=grid_thws,
            attention_mask=torch.full_like(input_ids, fill_value=True, dtype=torch.bool),
            labels=labels
        )