__init__.py 2.2 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68

import torch
import transformers
from dataclasses import dataclass, field

from vary.utils.constants import *


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Turn a list of per-sample dicts into one padded training batch.

    Each sample is read through the keys ``input_ids``, ``labels``,
    ``image`` and ``image_high``. The two image entries are passed to
    ``torch.stack``, so each is expected to be a sequence of tensors.
    """

    # Only ``pad_token_id`` is read from the tokenizer by this collator.
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        """Collate ``instances`` into a batch dictionary.

        Args:
            instances: Sequence of sample dicts (see the class docstring
                for the keys that are read).

        Returns:
            dict with ``input_ids`` (padded with the tokenizer pad id),
            ``labels`` (padded with ``IGNORE_INDEX``), ``attention_mask``
            (True wherever ``input_ids`` differs from the pad id) and
            ``images`` (one ``(image, image_high)`` tuple per sample).
        """
        # Gather every field sample by sample, in a fixed order.
        token_seqs = [sample["input_ids"] for sample in instances]
        label_seqs = [sample["labels"] for sample in instances]
        low_res = [torch.stack(sample['image']) for sample in instances]
        high_res = [torch.stack(sample['image_high']) for sample in instances]

        # Keep both stacked image tensors of a sample together as a pair.
        paired_images = list(zip(low_res, high_res))

        pad = torch.nn.utils.rnn.pad_sequence
        pad_id = self.tokenizer.pad_token_id

        padded_ids = pad(token_seqs, batch_first=True, padding_value=pad_id)
        # Labels use IGNORE_INDEX as filler instead of the pad token id.
        padded_labels = pad(label_seqs, batch_first=True, padding_value=IGNORE_INDEX)

        return {
            "input_ids": padded_ids,
            "labels": padded_labels,
            "attention_mask": padded_ids.ne(pad_id),
            "images": paired_images,
        }


def make_supervised_data_module(interleave, with_box, tokenizer, data_args):
    """Build the training dataset and data collator for supervised training.

    Args:
        interleave: Not used by this function; kept in the signature so
            existing callers keep working.
        with_box: Not used by this function; kept in the signature so
            existing callers keep working.
        tokenizer: Tokenizer handed to both the dataset and the collator.
        data_args: Object providing ``conversation_version``, ``datasets``
            and the multimodal settings forwarded below
            (``sep_image_conv_front``, ``image_token_len``,
            ``image_aspect_ratio``, ``use_im_start_end``,
            ``image_processor``, ``image_processor_high``, ``box_limit``).

    Returns:
        dict with ``train_dataset``, ``eval_dataset`` (always ``None``) and
        ``data_collator``.

    Raises:
        ValueError: If ``data_args.conversation_version`` is neither
            ``'mpt'`` nor ``'opt'``.
    """
    version = data_args.conversation_version

    # The dataset modules are imported inside the branches, so only the
    # module for the selected version is loaded.
    if version == 'mpt':
        from vary.data.conversation_dataset_qwen import ConversationDataset
        dataset_cls = ConversationDataset
    elif version == 'opt':
        from vary.data.caption_opt import CaptionDataset
        dataset_cls = CaptionDataset
    else:
        # Fail early with a clear message rather than hitting an unbound
        # ``dataset_cls`` further down.
        raise ValueError(
            f"Unsupported conversation_version {version!r}; "
            "expected 'mpt' or 'opt'."
        )

    train_dataset = dataset_cls(
        tokenizer=tokenizer,
        datasets=data_args.datasets,
        multimodal_cfg=dict(
            sep_image_conv_front=data_args.sep_image_conv_front,
            image_token_len=data_args.image_token_len,
            image_aspect_ratio=data_args.image_aspect_ratio,
            use_im_start_end=data_args.use_im_start_end,
            image_processor=data_args.image_processor,
            image_processor_high=data_args.image_processor_high,
            box_limit=data_args.box_limit,
        )
    )
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)