"""Data collators for Magma supervised fine-tuning datasets."""
import torch
from dataclasses import dataclass, field
from magma.processing_magma import MagmaProcessor
from typing import Dict, Optional, Sequence, List
import transformers 
from data.utils.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    # The processor's tokenizer supplies pad_token_id and model_max_length,
    # which drive padding and truncation of the text fields below.
    processor: MagmaProcessor

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad a list of per-example dicts into one training batch.

        Every instance must provide "input_ids", "labels", "pixel_values"
        and "image_sizes". Text tensors are right-padded to the longest
        example and clipped to the tokenizer's model_max_length; image
        tensors are concatenated per example, then zero-padded so each
        example contributes the same number of rows.
        """
        pad = torch.nn.utils.rnn.pad_sequence
        tok = self.processor.tokenizer
        max_len = tok.model_max_length

        token_ids = [inst["input_ids"] for inst in instances]
        label_ids = [inst["labels"] for inst in instances]
        pixels = [inst["pixel_values"] for inst in instances]
        sizes = [inst["image_sizes"] for inst in instances]

        # Pad first, then truncate to the model's maximum sequence length.
        token_ids = pad(token_ids, batch_first=True,
                        padding_value=tok.pad_token_id)[:, :max_len]
        label_ids = pad(label_ids, batch_first=True,
                        padding_value=IGNORE_INDEX)[:, :max_len]

        # Each example holds a list of image tensors; merge them along dim 0
        # before zero-padding across the batch.
        merged_pixels = [torch.cat(p, dim=0) for p in pixels]
        merged_sizes = [torch.cat(s, dim=0) for s in sizes]

        return dict(
            input_ids=token_ids,
            labels=label_ids,
            # Positions equal to the pad id are masked out; assumes real
            # tokens never equal pad_token_id — TODO confirm.
            attention_mask=token_ids.ne(tok.pad_token_id),
            pixel_values=pad(merged_pixels, batch_first=True, padding_value=0),
            image_sizes=pad(merged_sizes, batch_first=True, padding_value=0),
        )

@dataclass
class DataCollatorForHFDataset(object):
    """Collate hugging face examples for supervised fine-tuning."""

    # Supplies pad_token_id and model_max_length for padding/truncation.
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad a list of per-example dicts into one training batch.

        Each instance must provide "input_ids" and "labels"; "image",
        "add_im_loss" and "max_num_crops" are forwarded when present.
        Text tensors are right-padded to the longest example and clipped
        to the tokenizer's model_max_length.
        """
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        # Pad first, then truncate to the model's maximum sequence length.
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            # Positions equal to the pad id are masked out; assumes real
            # tokens never equal pad_token_id — TODO confirm.
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        # Images may differ in shape per example, so they are passed through
        # as a plain list rather than stacked into one tensor.
        if instances[0].get('image') is not None:
            batch['images'] = [instance['image'] for instance in instances]

        # NOTE(review): the flag is forced to True whenever the key exists,
        # ignoring the stored value — confirm this is intended upstream.
        if 'add_im_loss' in instances[0]:
            batch['add_im_loss'] = True
        # Assumes max_num_crops is identical across the batch, so only the
        # first instance's value is kept — TODO confirm.
        if 'max_num_crops' in instances[0]:
            batch['max_num_crops'] = instances[0]['max_num_crops']
        return batch