conversation_dataset_qwen.py 14.9 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329

import io
import os
import copy
import json
import logging
import torch
import random

from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from vary.data.base_dataset import BaseDataset
from vary.utils.constants import *
from vary.utils import conversation as conversation_lib

class ConversationDataset(BaseDataset):
    """Conversation format dataset stage2 fine-tuning."""
    """
    这段代码是一个名为ConversationDataset的Python类,它继承自BaseDataset类。
    这个类主要用于处理对话格式的数据集,特别是在多模态配置(multimodal_cfg)的情况下。

    这个方法中,首先调用了父类的构造函数来初始化一些基本的属性。然后,它设置了默认的对话模板,并打印了一些日志信息。
    接着,它遍历datasets,加载数据和图片路径,并将它们添加到列表中。最后,它将数据和图片路径的列表打乱,并将它们赋值给类的属性。
    """

    def __init__(self, datasets, tokenizer, multimodal_cfg):
        # 调用父类的构造函数
        super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
        # v0 version format conversation  
        # 初始化会话格式为mpt
        conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]     
        # 输出警告信息,说明正在格式化输入为mpt-fixed类型的会话
        logging.warning("Formatting inputs into conversation type: mpt-fixed")
        logging.warning("Loading data...")
        
        # 初始化数据字典列表和图片路径列表
        list_data_dict = []
        list_image_path = []

        # 遍历数据集名称,加载每个数据集
        for name in datasets.split("+"):
            # for name in vary_data_dict[name_all]:
            dataset = CONVERSATION_DATA[name]

            data_path = dataset['annotations']
            data = json.load(open(data_path, "r")) # 加载注释数据

            list_data_dict.extend(data) # 将数据添加到列表中

            image_path = dataset['images']

            list_image_path.extend([image_path] * len(data))  # 为每个数据项添加图片路径
            # 输出每个数据集提供的会话数量
            logging.warning(f"Data from {data_path} provide {len(data)} conversations.")
        # 确保数据项和图片路径数量一致
        assert len(list_data_dict) == len(list_image_path)
        # 输出总共有多少会话
        logging.warning(f"{len(list_data_dict)} conversations in total.")
        # 打乱数据和图片路径的组合,并重新分配
        a_new_list = list(zip(list_data_dict, list_image_path))
        random.shuffle(a_new_list)
        list_data_dict_new, list_image_path_new = zip(*a_new_list)
        self.list_data_dict = list_data_dict_new
        self.list_image_path = list_image_path_new


        # 设置图像标记相关的令牌
        self.im_patch_token = 151859
        self.im_start_token = 151857
        self.im_end_token = 151858
    
    def multimodal_processor(self, sources):
        """
        这段代码是一个名为`ConversationDataset`的Python类,它继承自`BaseDataset`类。
        这个类主要用于处理对话格式的数据集,特别是在多模态配置(`multimodal_cfg`)的情况下。

        multimodal_processor方法:这个方法用于处理多模态的源数据。它遍历sources,并检查每个源数据的第一个元素的'value'是否包含默认的图片标记。
        如果不包含,它会抛出一个断言错误。
        """

        """
        处理多模态源数据的方法。

        此方法遍历输入的源数据列表,对每条源数据进行处理。首先,检查是否在多模态配置中指定了在对话前添加图片标记(`sep_image_conv_front`),
        如果指定且第一个元素的'value'中不包含默认的图片标记,则抛出断言错误。然后,对于每个源数据中的句子,
        将默认图片标记替换为一定长度的图片标记字符串,如果在多模态配置中启用了使用图片开始和结束标记(`use_im_start_end`),
        则在图片标记字符串前后添加图片开始和结束标记。

        参数:
        - sources: 包含多模态源数据的列表。每个源数据是一个列表,其中第一个元素是特殊的源标记,后续元素是具体的句子信息。

        返回值:
        - 处理后的多模态源数据列表。
        """
        for source in sources:
             # 检查是否在对话前添加图片标记,并处理图片标记
            if self.multimodal_cfg['sep_image_conv_front']:
                assert DEFAULT_IMAGE_TOKEN in source[0]['value']
                source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
            # 重新格式化源标记,添加角色信息
            for sentence in source:
                # 准备替换的图片标记字符串
                replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
                # if self.multimodal_cfg['use_im_start_end']:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                # 如果启用了使用图片开始和结束标记,格式化图片标记字符串
                sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token)
        return sources

    def _tokenize_fn(self, strings):
        """Tokenize a list of strings."""

        """
        对字符串列表进行分词处理。

        此函数接受一个字符串列表,并使用实例关联的分词器对其进行分词。
        返回一个字典,其中包含分词后的输入、标签以及它们的长度。

        参数:
        - strings (list): 需要分词的字符串列表。

        返回:
        - dict: 包含以下键值对的字典:
          - input_ids (list): 分词后输入ID的列表。
          - labels (list): 分词后标签ID的列表(如果适用,否则与input_ids相同)。
          - input_ids_lens (list): 分词后输入的长度列表。
          - labels_lens (list): 分词后标签的长度列表。
        """
        # 对列表中的每个字符串进行分词并存储结果
        tokenized_list = [
            self.tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            ) for text in strings
        ]

        # 从每个分词后的字符串中提取input_ids
        input_ids = labels = [
            tokenized.input_ids[0] for tokenized in tokenized_list
        ]
         # 计算每个分词字符串的长度(不包括填充令牌)
        input_ids_lens = labels_lens = [
            tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
            for tokenized in tokenized_list
        ]
        # 返回包含所有相关分词数据的字典
        return dict(
            input_ids=input_ids,
            labels=labels,
            input_ids_lens=input_ids_lens,
            labels_lens=labels_lens,
        )

    def _mask_targets(self, target, tokenized_lens, speakers):
        """
        隐藏目标序列中特定部分的数据。
        
        参数:
        - target: 待处理的目标序列,通常是一个序列化的文本或音频表示。
        - tokenized_lens: 分词后每个序列的长度列表。
        - speakers: 对应于每个序列的说话者信息列表。
        
        返回值:
        - 修改后的target,其中特定部分被隐藏。
        """

        # cur_idx = 0
        # 初始化当前索引为第一个序列的长度
        cur_idx = tokenized_lens[0]
        tokenized_lens = tokenized_lens[1:] # 移除已经处理的第一个序列的长度
        target[:cur_idx] = IGNORE_INDEX # 隐藏目标序列的起始部分 

        # 遍历剩余的序列长度和说话者信息,隐藏相应部分的目标序列
        for tokenized_len, speaker in zip(tokenized_lens, speakers):
            if speaker.lower() == "human":
                target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX # 隐藏human说话者的部分,偏移量+2可能用于标识说话者
            cur_idx += tokenized_len # 更新当前索引以处理下一个序列

    def token_processor(self, sources):
        """
        处理对话源数据,将其转换为适用于模型输入的格式,包括标记化和目标掩码处理。

        参数:
        sources - 一个包含多个对话来源的列表,每个对话来源自身是一个列表,其中包含了多个句子字典,
                每个字典包含"from"和"value"键,分别指明句子的来源("human"或"gpt")和句子内容。

        返回值:
        一个字典,包含两个键:
        - input_ids: 经过标记化处理后的输入序列ID,格式为PyTorch张量。
        - labels: 对应的掩码处理后的目标序列ID,用于模型训练,格式同input_ids。
        """

        # 复制默认对话对象并定义角色
        conv = conversation_lib.default_conversation.copy()
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        # Apply prompt templates     # 应用对话模板
        conversations = []
        for i, source in enumerate(sources):
              # 如果首个消息不是来自人类,则跳过
            if roles[source[0]["from"]] != conv.roles[0]:
                # Skip the first one if it is not from human
                source = source[1:]

        # 初始化对话消息
            conv.messages = []
            for j, sentence in enumerate(source):
                role = roles[sentence["from"]]
                # 确保角色匹配预期顺序
                assert role == conv.roles[j % 2], f"{i}"
                conv.append_message(role, sentence["value"])
            conversations.append(conv.get_prompt())

        # Tokenize conversations

        # 对话内容标记化
        input_ids = self.tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
        ).input_ids

        # input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
        targets = input_ids.clone() # 确保使用的分隔符风格匹配预期
        assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

        #  目标掩码处理
        sep = conv.sep + conv.roles[1]
        for conversation, target in zip(conversations, targets):
            # 计算非填充token的总数
            total_len = int(target.ne(self.tokenizer.pad_token_id).sum())

            rounds = conversation.split(conv.sep)
            re_rounds = [conv.sep.join(rounds[:3])] # 系统 + 用户 + gpt
            for conv_idx in range(3, len(rounds), 2):
                re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))  # 用户 + gpt
            cur_len = 0
            target[:cur_len] = IGNORE_INDEX
            for i, rou in enumerate(re_rounds):
                if rou == "":
                    break

                parts = rou.split(sep)
                if len(parts) != 2:
                    break
                parts[0] += sep
                round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids)

                instruction_len = len(self.tokenizer(parts[0]).input_ids)
                target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

                cur_len += round_len
            target[cur_len:] = IGNORE_INDEX


            # 警告处理:标记化结果与预期长度不匹配
            if cur_len < self.tokenizer.model_max_length:
                if cur_len != total_len:
                    target[:] = IGNORE_INDEX
                    print(
                        f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                        f" (ignored)"
                    )

        return dict(
            input_ids=input_ids,
            labels=targets,
        )

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # data = self.list_data_dict[i]

        """
        根据索引i获取数据项,并进行预处理。如果数据项包含图像,则同时进行图像预处理。
        参数:
            i: 索引值,用于指定要获取的数据项。
        返回值:
            一个字典,包含经过处理的文本和图像数据,键包括input_ids, labels, image, image_high等。
        """
        # 深拷贝数据项以避免修改原始数据
        data = copy.deepcopy(self.list_data_dict[i])

        # 检查数据项是否为字典,并处理其中的图像数据
        if isinstance(data, dict):
            if 'image' in data:
                image_path = self.list_image_path[i]
                image_file = data['image']
                # 尝试打开并转换图像
                try:
                    image = Image.open(image_path + image_file).convert('RGB')
                except:
                    print(f'cannot identify image file {image_path + image_file}.')
                    return self.__getitem__(0)
                # 尝试打开并转换图像
                try:
                    image, image_1 = self.image_processor(image)
                except:
                    print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!')
                    return self.__getitem__(0)
            # 对对话数据进行预处理
            conversations = self.multimodal_processor([data["conversations"]])

        else:
            # 如果数据项不是字典,则直接将其作为对话处理
            conversations = [data]
        # 使用token处理器处理对话数据
        # align with fastchat & llava here, put the conversation into a list for tokenization
        data_dict = self.token_processor(conversations)
        # 精简处理后的数据字典,只保留input_ids和labels
        data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
        # 如果数据项包含图像,将图像数据添加到数据字典中
        if isinstance(data, dict) and 'image' in data:
            data_dict['image'] = [image]
            data_dict['image_high'] = [image_1]
        else:
            # 如果数据项不包含图像,为图像数据添加占位符
            crop_size = self.multimodal_cfg['image_processor'].crop_size
            data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
            data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
        return data_dict