caption_opt.py 14.6 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314

import io
import os
import copy
import json
import logging
import torch
import random

from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from vary.data.base_dataset import BaseDataset
from vary.utils.constants import *
from vary.utils import conversation as conversation_lib

# 这段代码定义了一个名为 CaptionDataset 的类,它继承自 BaseDataset 类。
# 这个类主要用于处理和存储对话格式的数据集,以便于后续的模型训练。

class CaptionDataset(BaseDataset):
    """Conversation format dataset stage2 fine-tuning."""
    """CaptionDataset 类的 __init__ 方法接收三个参数:datasets、tokenizer 和 multimodal_cfg。
    其中,datasets 是数据集的名称,tokenizer 是一个预训练的分词器,multimodal_cfg 是一个字典,包含了多模态配置信息。"""
    def __init__(self, datasets, tokenizer, multimodal_cfg):
         # 调用父类初始化方法
        super(CaptionDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
        # v0 version format conversation
        """在 __init__ 方法中,首先调用父类的初始化方法,然后设置默认的对话模板为 "default"。
        然后,对每个数据集,读取其注释数据和图像路径,并将这些数据添加到 list_data_dict 和 list_image_path 列表中。
        然后,将这两个列表打乱并重新组合,最后将这两个列表以及一些分词器生成的 token 保存为类的属性。"""
        # 设置默认对话模板
        conversation_lib.default_conversation = conversation_lib.conv_templates["default"]

        # 开始格式化输入数据
        logging.warning("Formatting inputs into conversation type: v0-fixed")
        logging.warning("Loading data...")
        
        # 初始化用于存储数据和图像路径的列表
        list_data_dict = []
        list_image_path = []
        
        # 遍历数据集名称,加载数据
        for name in datasets.split("+"):
            dataset = CONVERSATION_DATA[name] # in vary.utils # 从预定义的数据字典中获取数据集
            # 加载注释数据
            data_path = dataset['annotations']
            data = json.load(open(data_path, "r"))

            list_data_dict.extend(data)
            
            # 获取图像路径,并与注释数据对应
            image_path = dataset['images']
            list_image_path.extend([image_path] * len(data))
            
            # 输出数据加载信息
            logging.warning(f"Data from {data_path} provide {len(data)} conversations.")
        
        # 确保数据和图像路径的数量一致
        assert len(list_data_dict) == len(list_image_path)

        # 输出总数据量
        logging.warning(f"{len(list_data_dict)} conversations in total.")

        # 打乱并重新组合数据和图像路径
        a_new_list = list(zip(list_data_dict, list_image_path))
        random.shuffle(a_new_list)
        list_data_dict_new, list_image_path_new = zip(*a_new_list)
        self.list_data_dict = list_data_dict_new
        self.list_image_path = list_image_path_new

        # 为图像标记生成对应的token ID
        self.im_patch_token, self.im_start_token, self.im_end_token = tokenizer.convert_tokens_to_ids(
            [DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
    
    def multimodal_processor(self, sources):
        """
        处理多模态源数据,将图像标记替换为一定长度的图像标记序列。

        :param sources: 包含多个源数据的列表,每个源数据是一个字典列表,其中第一个字典代表图像信息,
                    其他字典代表与图像相关联的文本信息。
        :return: 经过图像标记替换处理后的源数据列表。
        """

        for source in sources:
            # 为每个源数据的图像信息设置默认的图像标记
            source[0]['value'] = DEFAULT_IMAGE_TOKEN
            for sentence in source:
                 # 根据配置,生成替换用的图像标记序列
                replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
                if self.multimodal_cfg['use_im_start_end']:
                    # 如果使用图像开始和结束标记,则在图像标记序列两端添加开始和结束标记
                    replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                 # 使用生成的图像标记序列替换句子中的默认图像标记
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
        return sources

    def _tokenize_fn(self, strings):
        """Tokenize a list of strings."""
        """
        对字符串列表进行分词处理,以供模型输入。

        此函数处理一个字符串列表,对每个字符串进行分词,填充或截断令牌到最大长度,并创建输入掩码。
        它还处理特定的分词情况,如确保每个输入的最后一个令牌是特定令牌。

        参数:
        - strings (list): 需要分词的字符串列表。

        返回:
        - dict: 包含以下键值对的字典:
          - input_ids (list): 表示每个输入字符串的令牌ID列表。
          - labels (list): 用作标签的令牌ID列表,通常与input_ids相同。
          - input_ids_lens (list): 分词后的字符串长度列表。
          - labels_lens (list): 标签长度列表。
        """

        # 对字符串列表中的每个字符串进行分词并准备模型输入 
        tokenized_list = [
            self.tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            ) for text in strings
        ]
        # 从分词结果中提取输入ID和标签
        input_ids = labels = [
            tokenized.input_ids[0] for tokenized in tokenized_list
        ]

        # 确保每个输入的最后一个令牌是特定令牌(此处为2)
        for idx, ii in enumerate(input_ids):
            if ii[-1] != 2:
                input_ids[idx][-1] = 2
                labels[idx][-1] = 2

        # 计算每个输入的非填充令牌长度
        input_ids_lens = labels_lens = [
            tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
            for tokenized in tokenized_list
        ]

        # 返回包含所有必需处理数据的字典
        return dict(
            input_ids=input_ids,
            labels=labels,
            input_ids_lens=input_ids_lens,
            labels_lens=labels_lens,
        )

    def _mask_targets(self, target, tokenized_lens, speakers):

        """
        隐藏目标序列中特定部分的数据。
        
        参数:
        - target: 待处理的目标序列,通常是一个序列化的文本或音频表示。
        - tokenized_lens: 分词后每个序列的长度列表。
        - speakers: 对应于每个序列的说话者信息列表。
        
        返回值:
        - 修改后的target,其中根据说话者信息和长度信息进行了掩码处理。
        """
        # cur_idx = 0
        # 初始化当前索引为第一个序列的长度
        cur_idx = tokenized_lens[0]
        tokenized_lens = tokenized_lens[1:] # 移除已经处理的第一个序列的长度
        target[:cur_idx] = IGNORE_INDEX   # 隐藏目标序列的起始部分

        # 遍历剩余的序列长度和说话者信息,进行相应的掩码处理
        for tokenized_len, speaker in zip(tokenized_lens, speakers):
            if speaker.lower() == "human":
                target[cur_idx:tokenized_len] = IGNORE_INDEX # 如果说话者是“human”,隐藏该部分目标序列
            cur_idx += tokenized_len # 更新当前索引位置,准备处理下一个序列


    def _add_speaker_and_signal(self, header, source, get_conversation=True):

        """Add speaker and start/end signal on each round."""
        """
        在每个回合中添加说话者和开始/结束信号。

        此方法处理一个句子列表,识别每个句子的说话者,并向句子添加开始和结束信号。
        如果`get_conversation`为真,它将所有句子连接成一个完整的对话字符串。

        参数:
        - header (str): 对话的初始消息或标题。
        - source (list): 每个句子是一个字典,至少包含'from'键和'value'键。
                        'from'键指定说话者,'value'键指定所说的话。
        - get_conversation (bool, 可选): 标志,表示是否将所有句子合并为一个单一的对话字符串。
                                        默认为True。

        返回:
        - str: 添加了说话者并带有适当开始和结束信号的最终对话字符串。
            如果`get_conversation`为False,则返回None。
        """

        BEGIN_SIGNAL = "</s>"
        END_SIGNAL = "\n"
        conversation = header    # 初始化对话,附带标题

        # 处理源列表中的每个句子
        for sentence in source:
            from_str = sentence["from"] # 获取说话者 获取说话信息
            # 将说话者名称映射到预定义的角色
            if from_str.lower() == "human":
                from_str = conversation_lib.default_conversation.roles[0]
            else:
                from_str = conversation_lib.default_conversation.roles[1]

            sentence["value"] = sentence["value"] + END_SIGNAL # 向句子添加结束信号
            if get_conversation:
                conversation += sentence["value"] # 若需要,将句子拼接到对话中
        conversation += BEGIN_SIGNAL 
        return conversation

    def token_processor(self, sources):
        """
        Given a list of sources, each is a conversation list. This transform:
        1. Add signal '### ' at the beginning each sentence, with end signal '\n';
        2. Concatenate conversations together;
        3. Tokenize the concatenated conversation;
        4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.

        multimodal_processor 方法是用来处理多模态数据的。这个方法接收一个名为 sources 的参数,这个参数应该是一个包含多个数据源的列表。
        每个数据源都应该是一个字典,包含 "value" 和 "sentence" 这两个键。
        这个方法会遍历每个数据源,将 "value" 键的值替换为特定的 token,然后返回处理后的数据源列表。
        """

        """
        处理输入的对话列表,进行标记化和掩码处理。
        
        参数:
        sources: 一个包含多个会话列表的列表,每个会话列表是一个字典列表,其中每个字典包含"value"和"sentence"键。
        
        返回:
        一个字典,包含输入ID和标签(被掩码处理的人类语言)的列表。
        """
        # add end signal and concatenate together
        # 为每个会话添加信号并在末尾连接
        conversations = []
        for source in sources:
            header = ''
            conversation = self._add_speaker_and_signal(header, source) # 添加说话者和信号到会话
            conversations.append(conversation)
        # 对会话进行标记化
        conversations_tokenized = self._tokenize_fn(conversations) # 标记化会话文本
        input_ids = conversations_tokenized["input_ids"]  # 获取输入ID

        # 创建目标的深拷贝,并掩码处理人类语言单词
        targets = copy.deepcopy(input_ids) # 创建输入ID的深拷贝作为目标
        for target, source in zip(targets, sources):  # 遍历目标和源,进行掩码处理
            tokenized_lens = self._tokenize_fn([header] + [s["value"] for s in source])["input_ids_lens"] # 计算每个说话者的标记长度
            speakers = [sentence["from"] for sentence in source] # 获取每个句子的说话者
            self._mask_targets(target, tokenized_lens, speakers) # 根据说话者信息掩码目标序列

        return dict(input_ids=input_ids, labels=targets) # 返回处理后的输入和目标

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # data = self.list_data_dict[i]
        """
        根据索引i获取数据集中的一个项目。
        
        参数:
        - i: int, 数据集中项的索引。
        
        返回值:
        - Dict[str, torch.Tensor]: 包含项目数据的字典,其中可能包括'image', 'image_high', 'input_ids', 'labels'键。
        """
        # 深拷贝数据以避免修改原始数据
        data = copy.deepcopy(self.list_data_dict[i])
        # 处理包含图像数据的项
        if isinstance(data, dict):
            if 'image' in data:
                image_path = self.list_image_path[i]
                image_file = data['image']
                # TODO this is a bug, because some json has wrong path
                # 尝试打开并处理图像
                try:
                    image = Image.open(image_path + image_file).convert('RGB')
                except:
                    print(f'cannot identify image file {image_path+image_file}.')
                    return self.__getitem__(0)
                # 尝试对图像进行进一步处理
                try:
                    image, image_high = self.image_processor(image)
                except:
                    print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!')
                    return self.__getitem__(0)

            conversations = self.multimodal_processor([data["conversations"]])
        else:
            # 对于不包含图像数据的项
            conversations = [data]

        # align with fastchat & llava here, put the conversation into a list for tokenization
        # 对会话数据进行标记化处理
        data_dict = self.token_processor(conversations)
        # 简化数据字典结构
        data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
        
        # 如果数据包含图像,添加图像数据到字典中
        if isinstance(data, dict) and 'image' in data:
            data_dict['image'] = [image]
            data_dict['image_high'] = [image_high]
        else:
            # 为不包含图像的数据项填充默认的图像张量
            crop_size = self.multimodal_cfg['image_processor'].crop_size
            data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
            # TODO sam is 1024*1024
            data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
        return data_dict