"magic_pdf/vscode:/vscode.git/clone" did not exist on "d244a1c1a7de8a450b07f88ba7483b93a7f786ab"
Commit 0bc22e1d authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
import io
import os
import copy
import json
import logging
import torch
import transformers
from typing import List, Optional, Tuple, Union, Dict, Sequence
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from vary.utils.constants import *
"""
这段代码定义了一个名为 BaseDataset 的基础数据集类,它继承自 PyTorch 的 Dataset 类。这个类主要用于处理和存储数据,以便于后续的模型训练。
"""
class BaseDataset(Dataset):
def __init__(
self,
datasets: str,  # path(s) of the dataset(s) to load
tokenizer: transformers.PreTrainedTokenizer,  # tokenizer used to tokenize the text
multimodal_cfg: dict  # multimodal configuration dict, e.g. image and text settings
):
"""
初始化BaseDataset类的实例。
参数:
- datasets: str,数据集的路径。
- tokenizer: transformers.PreTrainedTokenizer,用于对文本进行分词的预训练tokenizer。
- multimodal_cfg: dict,包含多种模态配置的字典,例如图像和文本的配置。
"""
super(BaseDataset, self).__init__()
self.tokenizer = tokenizer
self.multimodal_cfg = multimodal_cfg
# Log how many tokens are used to represent an image
logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing image")
def image_processor(self, image):
"""
image_processor 方法是用来处理图像的。这个方法接收一个图像作为输入,然后根据 multimodal_cfg 中的配置信息来处理这个图像。
首先,获取两个图像处理器:processor 和 processor_high。然后,根据 multimodal_cfg 中的 image_aspect_ratio 配置来决定如何处理图像。
如果 image_aspect_ratio 是 "keep",那么会保持图像的长宽比;如果 image_aspect_ratio 是 "pad",那么会将图像扩展到正方形;
如果 image_aspect_ratio 是其他值,那么会直接处理图像。最后,使用 processor_high 来处理复制的图像,并返回处理后的两个图像。
"""
"""
处理图像,通过两个处理器对图像进行处理:第一个处理器通常为预训练的模型(如CLIP、VIT),第二个处理器通常为设计的图像编码器(如SAM、SWIN、CNN)。
参数:
- image: 图像对象,待处理的图像。
返回值:
- image: 经过第一个处理器处理后的图像张量。
- image_high: 经过第二个处理器处理后的图像张量。
"""
processor = self.multimodal_cfg['image_processor'] # the first processor, usually a CLIP pretrained model (ViT)
processor_high = self.multimodal_cfg['image_processor_high'] # the second processor, usually a designed image encoder (SAM/Swin/CNN)
image_high = image.copy()
# TODO: the 'keep' / 'pad' options only apply to the first processor
# Decide how the first processor handles the image: keep the aspect ratio or pad to a square
if self.multimodal_cfg['image_aspect_ratio'] == 'keep':
# Keep the aspect ratio while resizing
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 448, 224
shortest_edge = int(min(max_len / aspect_ratio, min_len))
# Convert the image to a tensor
image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0]
elif self.multimodal_cfg['image_aspect_ratio'] == 'pad':
# Pad the image to a square
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img) # for simpler box processing
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img) # for simpler box processing
return result
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": 224})['pixel_values'][0]
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
# Process the high-resolution copy with the second processor
image_high = processor_high(image_high)
return image, image_high
def __len__(self):
return len(self.list_data_dict)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
pass
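# Minimal subclassing sketch (illustrative only, not part of the original source): a concrete
# dataset is expected to populate self.list_data_dict in its own __init__ and to call
# self.image_processor from __getitem__, roughly:
#   class MyDataset(BaseDataset):
#       def __init__(self, datasets, tokenizer, multimodal_cfg):
#           super().__init__(datasets, tokenizer, multimodal_cfg)
#           self.list_data_dict = json.load(open(datasets, "r"))  # hypothetical annotation file
#       def __getitem__(self, i):
#           image = Image.open(self.list_data_dict[i]['image']).convert('RGB')  # hypothetical key
#           image, image_high = self.image_processor(image)
#           ...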
import io
import os
import copy
import json
import logging
import torch
import random
from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from vary.data.base_dataset import BaseDataset
from vary.utils.constants import *
from vary.utils import conversation as conversation_lib
# CaptionDataset inherits from BaseDataset and handles conversation-format data
# for stage-2 fine-tuning.
class CaptionDataset(BaseDataset):
"""Conversation format dataset stage2 fine-tuning."""
"""CaptionDataset 类的 __init__ 方法接收三个参数:datasets、tokenizer 和 multimodal_cfg。
其中,datasets 是数据集的名称,tokenizer 是一个预训练的分词器,multimodal_cfg 是一个字典,包含了多模态配置信息。"""
def __init__(self, datasets, tokenizer, multimodal_cfg):
# Call the parent class initializer
super(CaptionDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
# v0 version format conversation
"""在 __init__ 方法中,首先调用父类的初始化方法,然后设置默认的对话模板为 "default"。
然后,对每个数据集,读取其注释数据和图像路径,并将这些数据添加到 list_data_dict 和 list_image_path 列表中。
然后,将这两个列表打乱并重新组合,最后将这两个列表以及一些分词器生成的 token 保存为类的属性。"""
# 设置默认对话模板
conversation_lib.default_conversation = conversation_lib.conv_templates["default"]
# 开始格式化输入数据
logging.warning("Formatting inputs into conversation type: v0-fixed")
logging.warning("Loading data...")
# 初始化用于存储数据和图像路径的列表
list_data_dict = []
list_image_path = []
# 遍历数据集名称,加载数据
for name in datasets.split("+"):
dataset = CONVERSATION_DATA[name] # defined in vary.utils; look up the dataset entry
# Load the annotation data
data_path = dataset['annotations']
data = json.load(open(data_path, "r"))
list_data_dict.extend(data)
# Record the image root for every annotation entry
image_path = dataset['images']
list_image_path.extend([image_path] * len(data))
# Report how many conversations this dataset contributed
logging.warning(f"Data from {data_path} provides {len(data)} conversations.")
# The number of annotations must match the number of image paths
assert len(list_data_dict) == len(list_image_path)
# Report the total number of conversations
logging.warning(f"{len(list_data_dict)} conversations in total.")
# Shuffle the annotation / image-path pairs together
a_new_list = list(zip(list_data_dict, list_image_path))
random.shuffle(a_new_list)
list_data_dict_new, list_image_path_new = zip(*a_new_list)
self.list_data_dict = list_data_dict_new
self.list_image_path = list_image_path_new
# Look up the token ids of the image-related special tokens
self.im_patch_token, self.im_start_token, self.im_end_token = tokenizer.convert_tokens_to_ids(
[DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
def multimodal_processor(self, sources):
"""
处理多模态源数据,将图像标记替换为一定长度的图像标记序列。
:param sources: 包含多个源数据的列表,每个源数据是一个字典列表,其中第一个字典代表图像信息,
其他字典代表与图像相关联的文本信息。
:return: 经过图像标记替换处理后的源数据列表。
"""
for source in sources:
# Reset the first message's value to the default image token
source[0]['value'] = DEFAULT_IMAGE_TOKEN
for sentence in source:
# Build the replacement sequence of image patch tokens
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
if self.multimodal_cfg['use_im_start_end']:
# Wrap the patch tokens with the image start/end tokens when configured
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
# Replace the default image token in the sentence with the expanded sequence
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
return sources
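# Illustrative example (not from the original source): with image_token_len=256 and
# use_im_start_end=True, the first turn's value is reset to "<image>" and then expanded to
# "<img>" + "<imgpad>" * 256 + "</img>" before tokenization.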
def _tokenize_fn(self, strings):
"""Tokenize a list of strings."""
"""
对字符串列表进行分词处理,以供模型输入。
此函数处理一个字符串列表,对每个字符串进行分词,填充或截断令牌到最大长度,并创建输入掩码。
它还处理特定的分词情况,如确保每个输入的最后一个令牌是特定令牌。
参数:
- strings (list): 需要分词的字符串列表。
返回:
- dict: 包含以下键值对的字典:
- input_ids (list): 表示每个输入字符串的令牌ID列表。
- labels (list): 用作标签的令牌ID列表,通常与input_ids相同。
- input_ids_lens (list): 分词后的字符串长度列表。
- labels_lens (list): 标签长度列表。
"""
# Tokenize every string and prepare the model inputs
tokenized_list = [
self.tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=self.tokenizer.model_max_length,
truncation=True,
) for text in strings
]
# Extract the input ids; labels reuse the same tensors
input_ids = labels = [
tokenized.input_ids[0] for tokenized in tokenized_list
]
# Force the last token of every input to the eos id (2 here)
for idx, ii in enumerate(input_ids):
if ii[-1] != 2:
input_ids[idx][-1] = 2
labels[idx][-1] = 2
# Count the non-padding tokens in every input
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
# Return all of the processed pieces
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(self, target, tokenized_lens, speakers):
"""
隐藏目标序列中特定部分的数据。
参数:
- target: 待处理的目标序列,通常是一个序列化的文本或音频表示。
- tokenized_lens: 分词后每个序列的长度列表。
- speakers: 对应于每个序列的说话者信息列表。
返回值:
- 修改后的target,其中根据说话者信息和长度信息进行了掩码处理。
"""
# cur_idx = 0
# Start after the header (the first tokenized length)
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:] # drop the header length that was just consumed
target[:cur_idx] = IGNORE_INDEX # mask the header portion of the target
# Walk through the remaining segments and mask the human turns
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker.lower() == "human":
target[cur_idx:cur_idx + tokenized_len] = IGNORE_INDEX # mask segments spoken by "human"
cur_idx += tokenized_len # advance to the next segment
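# Illustrative example (not from the original source): for tokenized_lens = [3, 5, 4] and
# speakers = ["human", "gpt"], positions 0-2 (the header) and 3-7 (the human turn) are set to
# IGNORE_INDEX, while the gpt turn at positions 8-11 remains a training target.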
def _add_speaker_and_signal(self, header, source, get_conversation=True):
"""Add speaker and start/end signal on each round."""
"""
在每个回合中添加说话者和开始/结束信号。
此方法处理一个句子列表,识别每个句子的说话者,并向句子添加开始和结束信号。
如果`get_conversation`为真,它将所有句子连接成一个完整的对话字符串。
参数:
- header (str): 对话的初始消息或标题。
- source (list): 每个句子是一个字典,至少包含'from'键和'value'键。
'from'键指定说话者,'value'键指定所说的话。
- get_conversation (bool, 可选): 标志,表示是否将所有句子合并为一个单一的对话字符串。
默认为True。
返回:
- str: 添加了说话者并带有适当开始和结束信号的最终对话字符串。
如果`get_conversation`为False,则返回None。
"""
BEGIN_SIGNAL = "</s>"
END_SIGNAL = "\n"
conversation = header # start the conversation with the header
# Process every sentence in the source
for sentence in source:
from_str = sentence["from"] # who is speaking
# Map the speaker name to a predefined role
if from_str.lower() == "human":
from_str = conversation_lib.default_conversation.roles[0]
else:
from_str = conversation_lib.default_conversation.roles[1]
sentence["value"] = sentence["value"] + END_SIGNAL # 向句子添加结束信号
if get_conversation:
conversation += sentence["value"] # 若需要,将句子拼接到对话中
conversation += BEGIN_SIGNAL
return conversation
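# Illustrative example (not from the original source): for a human turn "What is written here?"
# followed by a gpt turn "A receipt.", each value has the end signal "\n" appended and the turns
# are concatenated onto the header, with "</s>" appended as the begin signal, giving roughly
# header + "What is written here?\n" + "A receipt.\n" + "</s>".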
def token_processor(self, sources):
"""
Given a list of sources, each is a conversation list. This transform:
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
multimodal_processor 方法是用来处理多模态数据的。这个方法接收一个名为 sources 的参数,这个参数应该是一个包含多个数据源的列表。
每个数据源都应该是一个字典,包含 "value" 和 "sentence" 这两个键。
这个方法会遍历每个数据源,将 "value" 键的值替换为特定的 token,然后返回处理后的数据源列表。
"""
"""
处理输入的对话列表,进行标记化和掩码处理。
参数:
sources: 一个包含多个会话列表的列表,每个会话列表是一个字典列表,其中每个字典包含"value"和"sentence"键。
返回:
一个字典,包含输入ID和标签(被掩码处理的人类语言)的列表。
"""
# Add the end signal to every turn and concatenate the conversations
conversations = []
for source in sources:
header = ''
conversation = self._add_speaker_and_signal(header, source) # add speaker roles and signals
conversations.append(conversation)
# Tokenize the conversations
conversations_tokenized = self._tokenize_fn(conversations) # tokenize the conversation text
input_ids = conversations_tokenized["input_ids"] # model inputs
# Deep-copy the inputs as targets and mask the human turns
targets = copy.deepcopy(input_ids) # labels start as a copy of the inputs
for target, source in zip(targets, sources): # mask each target against its source
tokenized_lens = self._tokenize_fn([header] + [s["value"] for s in source])["input_ids_lens"] # per-turn token lengths
speakers = [sentence["from"] for sentence in source] # speaker of each turn
self._mask_targets(target, tokenized_lens, speakers) # mask the human turns in the target
return dict(input_ids=input_ids, labels=targets) # processed inputs and labels
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# data = self.list_data_dict[i]
"""
根据索引i获取数据集中的一个项目。
参数:
- i: int, 数据集中项的索引。
返回值:
- Dict[str, torch.Tensor]: 包含项目数据的字典,其中可能包括'image', 'image_high', 'input_ids', 'labels'键。
"""
# 深拷贝数据以避免修改原始数据
data = copy.deepcopy(self.list_data_dict[i])
# Items that carry image data
if isinstance(data, dict):
if 'image' in data:
image_path = self.list_image_path[i]
image_file = data['image']
# TODO: this is a bug - some json entries have wrong image paths
# Try to open the image
try:
image = Image.open(image_path + image_file).convert('RGB')
except:
print(f'cannot identify image file {image_path+image_file}.')
return self.__getitem__(0)
# Try to run the image through the image processors
try:
image, image_high = self.image_processor(image)
except:
print(f'image {image_file} is broken or grayscale, falling back to the 0-th sample instead.')
return self.__getitem__(0)
conversations = self.multimodal_processor([data["conversations"]])
else:
# Items without image data
conversations = [data]
# align with fastchat & llava here, put the conversation into a list for tokenization
# Tokenize the conversation data
data_dict = self.token_processor(conversations)
# Flatten the data dict to a single sample
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
# Attach the processed images when the item has one
if isinstance(data, dict) and 'image' in data:
data_dict['image'] = [image]
data_dict['image_high'] = [image_high]
else:
# Otherwise fill in zero image tensors as placeholders
crop_size = self.multimodal_cfg['image_processor'].crop_size
data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
# TODO sam is 1024*1024
data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
return data_dict
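# Minimal usage sketch (hypothetical names and config values, not part of the original source;
# "caption_data" stands for a key registered in CONVERSATION_DATA):
#   tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-125m", use_fast=False)
#   multimodal_cfg = dict(image_token_len=256, use_im_start_end=True, image_aspect_ratio='square',
#                         image_processor=clip_processor, image_processor_high=sam_transform)
#   dataset = CaptionDataset("caption_data", tokenizer, multimodal_cfg)
#   sample = dataset[0]  # dict with 'input_ids', 'labels', 'image', 'image_high'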
import io
import os
import copy
import json
import logging
import torch
import random
from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from vary.data.base_dataset import BaseDataset
from vary.utils.constants import *
from vary.utils import conversation as conversation_lib
class ConversationDataset(BaseDataset):
"""Conversation format dataset stage2 fine-tuning."""
"""
这段代码是一个名为ConversationDataset的Python类,它继承自BaseDataset类。
这个类主要用于处理对话格式的数据集,特别是在多模态配置(multimodal_cfg)的情况下。
这个方法中,首先调用了父类的构造函数来初始化一些基本的属性。然后,它设置了默认的对话模板,并打印了一些日志信息。
接着,它遍历datasets,加载数据和图片路径,并将它们添加到列表中。最后,它将数据和图片路径的列表打乱,并将它们赋值给类的属性。
"""
def __init__(self, datasets, tokenizer, multimodal_cfg):
# Call the parent constructor
super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
# mpt version format conversation
# Use the mpt conversation template
conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
# Report that inputs are being formatted as mpt-fixed conversations
logging.warning("Formatting inputs into conversation type: mpt-fixed")
logging.warning("Loading data...")
# Lists that hold the annotation data and the matching image paths
list_data_dict = []
list_image_path = []
# Iterate over the dataset names and load each one
for name in datasets.split("+"):
# for name in vary_data_dict[name_all]:
dataset = CONVERSATION_DATA[name]
data_path = dataset['annotations']
data = json.load(open(data_path, "r")) # load the annotation data
list_data_dict.extend(data) # append the annotations
image_path = dataset['images']
list_image_path.extend([image_path] * len(data)) # one image root per annotation entry
# Report how many conversations this dataset contributed
logging.warning(f"Data from {data_path} provides {len(data)} conversations.")
# The number of annotations must match the number of image paths
assert len(list_data_dict) == len(list_image_path)
# Report the total number of conversations
logging.warning(f"{len(list_data_dict)} conversations in total.")
# Shuffle the annotation / image-path pairs together
a_new_list = list(zip(list_data_dict, list_image_path))
random.shuffle(a_new_list)
list_data_dict_new, list_image_path_new = zip(*a_new_list)
self.list_data_dict = list_data_dict_new
self.list_image_path = list_image_path_new
# Hard-coded ids of the image-related special tokens
self.im_patch_token = 151859
self.im_start_token = 151857
self.im_end_token = 151858
def multimodal_processor(self, sources):
"""
这段代码是一个名为`ConversationDataset`的Python类,它继承自`BaseDataset`类。
这个类主要用于处理对话格式的数据集,特别是在多模态配置(`multimodal_cfg`)的情况下。
multimodal_processor方法:这个方法用于处理多模态的源数据。它遍历sources,并检查每个源数据的第一个元素的'value'是否包含默认的图片标记。
如果不包含,它会抛出一个断言错误。
"""
"""
处理多模态源数据的方法。
此方法遍历输入的源数据列表,对每条源数据进行处理。首先,检查是否在多模态配置中指定了在对话前添加图片标记(`sep_image_conv_front`),
如果指定且第一个元素的'value'中不包含默认的图片标记,则抛出断言错误。然后,对于每个源数据中的句子,
将默认图片标记替换为一定长度的图片标记字符串,如果在多模态配置中启用了使用图片开始和结束标记(`use_im_start_end`),
则在图片标记字符串前后添加图片开始和结束标记。
参数:
- sources: 包含多模态源数据的列表。每个源数据是一个列表,其中第一个元素是特殊的源标记,后续元素是具体的句子信息。
返回值:
- 处理后的多模态源数据列表。
"""
for source in sources:
# Optionally move the image token to the front of the conversation
if self.multimodal_cfg['sep_image_conv_front']:
assert DEFAULT_IMAGE_TOKEN in source[0]['value']
source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
# Expand the image token in every sentence
for sentence in source:
# Build the replacement sequence of image patch tokens
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
# if self.multimodal_cfg['use_im_start_end']:
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
# The patch tokens are wrapped with the image start/end tokens unconditionally here
sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token)
return sources
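# Illustrative example (not from the original source): with sep_image_conv_front=True, a first
# message "<image>\nWhat is written here?" is rewritten to
# "<image>" + conv.sep + conv.roles[0] + ": What is written here?", and every "<image>"
# placeholder is then expanded to "<img>" + "<imgpad>" * image_token_len + "</img>".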
def _tokenize_fn(self, strings):
"""Tokenize a list of strings."""
"""
对字符串列表进行分词处理。
此函数接受一个字符串列表,并使用实例关联的分词器对其进行分词。
返回一个字典,其中包含分词后的输入、标签以及它们的长度。
参数:
- strings (list): 需要分词的字符串列表。
返回:
- dict: 包含以下键值对的字典:
- input_ids (list): 分词后输入ID的列表。
- labels (list): 分词后标签ID的列表(如果适用,否则与input_ids相同)。
- input_ids_lens (list): 分词后输入的长度列表。
- labels_lens (list): 分词后标签的长度列表。
"""
# Tokenize every string and keep the results
tokenized_list = [
self.tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=self.tokenizer.model_max_length,
truncation=True,
) for text in strings
]
# Extract the input_ids; labels reuse the same tensors
input_ids = labels = [
tokenized.input_ids[0] for tokenized in tokenized_list
]
# Count the non-padding tokens in every tokenized string
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
# Return all of the tokenized pieces
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(self, target, tokenized_lens, speakers):
"""
隐藏目标序列中特定部分的数据。
参数:
- target: 待处理的目标序列,通常是一个序列化的文本或音频表示。
- tokenized_lens: 分词后每个序列的长度列表。
- speakers: 对应于每个序列的说话者信息列表。
返回值:
- 修改后的target,其中特定部分被隐藏。
"""
# cur_idx = 0
# Start after the header (the first tokenized length)
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:] # drop the header length that was just consumed
target[:cur_idx] = IGNORE_INDEX # mask the header portion of the target
# Walk through the remaining segments and mask the human turns
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker.lower() == "human":
target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX # mask the human turn; the +2 offset skips the speaker tokens
cur_idx += tokenized_len # advance to the next segment
def token_processor(self, sources):
"""
处理对话源数据,将其转换为适用于模型输入的格式,包括标记化和目标掩码处理。
参数:
sources - 一个包含多个对话来源的列表,每个对话来源自身是一个列表,其中包含了多个句子字典,
每个字典包含"from"和"value"键,分别指明句子的来源("human"或"gpt")和句子内容。
返回值:
一个字典,包含两个键:
- input_ids: 经过标记化处理后的输入序列ID,格式为PyTorch张量。
- labels: 对应的掩码处理后的目标序列ID,用于模型训练,格式同input_ids。
"""
# Copy the default conversation object and map the speaker roles
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
# Skip the first message if it is not from the human
if roles[source[0]["from"]] != conv.roles[0]:
source = source[1:]
# Reset the conversation messages
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
# Roles must alternate in the expected order
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
input_ids = self.tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=self.tokenizer.model_max_length,
truncation=True,
).input_ids
# input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
targets = input_ids.clone()
# The separator style must be MPT for the masking logic below
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
# Mask the targets
sep = conv.sep + conv.roles[1]
for conversation, target in zip(conversations, targets):
# Total number of non-padding tokens
total_len = int(target.ne(self.tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep)
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
for conv_idx in range(3, len(rounds), 2):
re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
cur_len = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids)
instruction_len = len(self.tokenizer(parts[0]).input_ids)
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
# Warn when the tokenized length does not match the expected length
if cur_len < self.tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
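# Illustrative example (not from the original source): with the MPT-style template, each prompt
# is a ChatML-like sequence of system / user / assistant turns separated by conv.sep. The prompt
# is split on conv.sep into rounds, the rounds are regrouped as (system + user + assistant) and
# then (user + assistant) pairs, and within every round the tokens up to and including the
# assistant role header are set to IGNORE_INDEX, so the loss is computed only on assistant replies.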
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# data = self.list_data_dict[i]
"""
根据索引i获取数据项,并进行预处理。如果数据项包含图像,则同时进行图像预处理。
参数:
i: 索引值,用于指定要获取的数据项。
返回值:
一个字典,包含经过处理的文本和图像数据,键包括input_ids, labels, image, image_high等。
"""
# 深拷贝数据项以避免修改原始数据
data = copy.deepcopy(self.list_data_dict[i])
# Items that carry image data
if isinstance(data, dict):
if 'image' in data:
image_path = self.list_image_path[i]
image_file = data['image']
# Try to open and convert the image
try:
image = Image.open(image_path + image_file).convert('RGB')
except:
print(f'cannot identify image file {image_path + image_file}.')
return self.__getitem__(0)
# Try to run the image through the image processors
try:
image, image_1 = self.image_processor(image)
except:
print(f'image {image_file} is broken or grayscale, falling back to the 0-th sample instead.')
return self.__getitem__(0)
# Preprocess the conversation data
conversations = self.multimodal_processor([data["conversations"]])
else:
# Items that are not dicts are treated as plain conversations
conversations = [data]
# Tokenize the conversation data
# align with fastchat & llava here, put the conversation into a list for tokenization
data_dict = self.token_processor(conversations)
# Keep only input_ids and labels for this sample
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
# Attach the processed images when the item has one
if isinstance(data, dict) and 'image' in data:
data_dict['image'] = [image]
data_dict['image_high'] = [image_1]
else:
# Otherwise fill in zero image tensors as placeholders
crop_size = self.multimodal_cfg['image_processor'].crop_size
data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
return data_dict
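# Minimal usage sketch (hypothetical names and config values, not part of the original source;
# "pdf_en" stands for a key registered in CONVERSATION_DATA):
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
#   multimodal_cfg = dict(image_token_len=256, use_im_start_end=True, sep_image_conv_front=False,
#                         image_aspect_ratio='square', image_processor=clip_processor,
#                         image_processor_high=sam_transform)
#   dataset = ConversationDataset("pdf_en", tokenizer, multimodal_cfg)
#   sample = dataset[0]  # dict with 'input_ids', 'labels', 'image', 'image_high'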
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.vision_encoder.sam import build_sam_vit_b
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
# This script loads a pretrained model and runs it on an image, using Hugging Face Transformers and PIL.
def load_image(image_file):
"""
加载图像。
参数:
- image_file: 图像的文件路径或URL。
返回值:
- image: 打开的图像,转换为RGB模式。
"""
# 判断图像文件是否是URL
if image_file.startswith('http') or image_file.startswith('https'):
# Download the image from the URL
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
# Open the image from the local file system
image = Image.open(image_file).convert('RGB')
return image
def eval_model(args):
"""
评估模型性能。
参数:
- args: 一个包含模型名称和图像文件路径等信息的参数对象。
说明:
此函数加载指定的模型和对应的tokenizer,将输入的图像转换为文本描述。
使用关键词停止准则以防止生成不必要的文本。
"""
# Model
# 初始化模型和tokenizer
disable_torch_init() # 禁用PyTorch的默认初始化
model_name = os.path.expanduser(args.model_name) # 解析模型名称
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) # 解析模型名称
model = varyOPTForCausalLM.from_pretrained(model_name) # 从预训练加载模型
# 将模型转移到GPU并设置数据类型
model.to(device='cuda', dtype=torch.bfloat16)
# 设置图像处理和token长度
image_processor_high = test_transform # high-resolution image preprocessing
image_token_len = 256 # number of image tokens
# Build the input prompt
prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
inputs = tokenizer([prompt]) # tokenize the prompt
# Load the input image
image = load_image(args.image_file)
image_1 = image.copy()
# Preprocess the high-resolution copy
image_tensor_1 = image_processor_high(image_1).to(torch.bfloat16)
# Prepare the model input ids
input_ids = torch.as_tensor(inputs.input_ids).cuda()
# Stopping criterion
stop_str = '</s>'
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
# Set up streaming of the generated text
# A TextStreamer prints tokens as they are generated. generate() then runs on CUDA under
# bfloat16 autocast, conditioned on the input tokens and the image tensors, with sampling,
# single-beam search, a cap on new tokens, and the keyword stopping criterion.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Run inference with mixed precision
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = model.generate(
input_ids,
images=[(image_tensor_1.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).cuda())],
do_sample=True,
num_beams = 1,
streamer=streamer,
max_new_tokens=4096,
stopping_criteria=[stopping_criteria]
)
# input_token_len = input_ids.shape[1]
# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
# if outputs.endswith(stop_str):
# outputs = outputs[:-len(stop_str)]
# outputs = outputs.strip()
# print(outputs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--image-file", type=str, required=True)
# parser.add_argument("--query", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
args = parser.parse_args()
eval_model(args)
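# Example invocation (the script filename and checkpoint path are hypothetical):
#   python run_eval_opt.py --model-name /path/to/vary-tiny-checkpoint --image-file ./page.png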
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.vision_encoder.sam import build_sam_vit_b
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def eval_model(args):
# Model
disable_torch_init()
model_name = os.path.expanduser(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = varyOPTForCausalLM.from_pretrained(model_name)
model.to(device='cuda', dtype=torch.bfloat16)
# image_processor_high = test_transform
image_processor_high = BlipImageEvalProcessor(image_size=1024)
image_token_len = 256
qs = "Provide the OCR results of this image."
# qs = "detect Person in this image.Your answer should be structured precisely according to the category:[xmin,ymin,xmax,ymax] format."
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
conv_mode = "v1"
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
print(prompt)
image = load_image(args.image_file)
image_1 = image.copy()
image_tensor_1 = image_processor_high(image_1).to(torch.bfloat16)
input_ids = torch.as_tensor(inputs.input_ids).cuda()
stop_str = '</s>'
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = model.generate(
input_ids,
images=[(image_tensor_1.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).cuda())],
do_sample=True,
num_beams = 1,
streamer=streamer,
max_new_tokens=2048,
stopping_criteria=[stopping_criteria]
)
# input_token_len = input_ids.shape[1]
# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
# if outputs.endswith(stop_str):
# outputs = outputs[:-len(stop_str)]
# outputs = outputs.strip()
# print(outputs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--image-file", type=str, required=True)
# parser.add_argument("--query", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
args = parser.parse_args()
eval_model(args)
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from vary.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def eval_model(args):
# Model
disable_torch_init()
model_name = os.path.expanduser(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = varyQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', trust_remote_code=True)
model.to(device='cuda', dtype=torch.bfloat16)
# TODO download clip-vit in huggingface
image_processor = CLIPImageProcessor.from_pretrained("/home/wanglch/projects/Vary/cache/vit-large-patch14", torch_dtype=torch.float16)
image_processor_high = test_transform
use_im_start_end = True
image_token_len = 256
qs = 'Provide the ocr results of this image.'
# qs = 'Detect the red hat in this image.'
# qs = 'Describe this image in within 100 words.'
if use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
conv_mode = "mpt"
args.conv_mode = conv_mode
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
image = load_image(args.image_file)
image_1 = image.copy()
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
image_tensor_1 = image_processor_high(image_1)
input_ids = torch.as_tensor(inputs.input_ids).cuda()
# stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = model.generate(
input_ids,
images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
do_sample=True,
num_beams = 1,
# temperature=0.2,
streamer=streamer,
max_new_tokens=2048,
stopping_criteria=[stopping_criteria]
)
# print(output_ids)
# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
# # conv.messages[-1][-1] = outputs
# if outputs.endswith(stop_str):
# outputs = outputs[:-len(stop_str)]
# outputs = outputs.strip()
# print(outputs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--image-file", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
args = parser.parse_args()
eval_model(args)
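# Example invocation (the script filename and checkpoint path are hypothetical):
#   python run_eval_qwen_vary.py --model-name /path/to/vary-base-checkpoint --image-file ./page.png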
from .vary_opt import varyOPTModel, varyOPTForCausalLM
from .vary_qwen_vary import varyQwenModel, varyQwenForCausalLM, varyConfig
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from transformers import PretrainedConfig
class QWenConfig(PretrainedConfig):
model_type = "qwen"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
emb_dropout_prob=0.0,
attn_dropout_prob=0.0,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
max_position_embeddings=8192,
scale_attn_weights=True,
use_cache=True,
bf16=False,
fp16=False,
fp32=False,
kv_channels=128,
rotary_pct=1.0,
rotary_emb_base=10000,
use_dynamic_ntk=True,
use_logn_attn=True,
use_flash_attn="auto",
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.emb_dropout_prob = emb_dropout_prob
self.attn_dropout_prob = attn_dropout_prob
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.bf16 = bf16
self.fp16 = fp16
self.fp32 = fp32
self.kv_channels = kv_channels
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.no_bias = no_bias
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs
)
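# Minimal sketch (not part of the original source): the config can be instantiated with the
# defaults above and individual fields overridden as keyword arguments, e.g.
#   config = QWenConfig(bf16=True, use_flash_attn=False)
#   print(config.hidden_size, config.num_hidden_layers)  # 4096 32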
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Generation support."""
from typing import Tuple, List, Union, Iterable
import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor
logger = logging.get_logger(__name__)
# Types.
HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]
def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
for tokens in batch:
context_length = len(tokens)
if context_length < seq_length:
tokens.extend([pad_id] * (seq_length - context_length))
return batch
def get_ltor_masks_and_position_ids(
data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss,
):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indices where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indices from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, loss_mask, position_ids
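# Illustrative example (not from the original source): for data of shape (1, 4) with all reset
# flags False, attention_mask is a (1, 1, 4, 4) lower-triangular mask converted to booleans
# (True marks positions to be ignored), loss_mask is all ones, and position_ids is [[0, 1, 2, 3]].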
def get_batch(context_tokens: torch.LongTensor, eod_id: int):
"""Generate batch from context tokens."""
# Move to GPU.
tokens = context_tokens.contiguous().to(context_tokens.device)
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
eod_id,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False,
)
return tokens, attention_mask, position_ids
def get_stop_words_ids(chat_format, tokenizer):
if chat_format == "raw":
stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
elif chat_format == "chatml":
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
return stop_words_ids
def make_context(
tokenizer: PreTrainedTokenizer,
query: str,
history: List[Tuple[str, str]] = None,
system: str = "",
max_window_size: int = 6144,
chat_format: str = "chatml",
):
if history is None:
history = []
if chat_format == "chatml":
im_start, im_end = "<|im_start|>", "<|im_end|>"
im_start_tokens = [tokenizer.im_start_id]
im_end_tokens = [tokenizer.im_end_id]
nl_tokens = tokenizer.encode("\n")
def _tokenize_str(role, content):
return f"{role}\n{content}", tokenizer.encode(
role, allowed_special=set(tokenizer.IMAGE_ST)
) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))
system_text, system_tokens_part = _tokenize_str("system", system)
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
raw_text = ""
context_tokens = []
for turn_query, turn_response in reversed(history):
query_text, query_tokens_part = _tokenize_str("user", turn_query)
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
if turn_response is not None:
response_text, response_tokens_part = _tokenize_str(
"assistant", turn_response
)
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
prev_chat = (
f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
)
else:
next_context_tokens = nl_tokens + query_tokens + nl_tokens
prev_chat = f"\n{im_start}{query_text}{im_end}\n"
current_context_size = (
len(system_tokens) + len(next_context_tokens) + len(context_tokens)
)
if current_context_size < max_window_size:
context_tokens = next_context_tokens + context_tokens
raw_text = prev_chat + raw_text
else:
break
context_tokens = system_tokens + context_tokens
raw_text = f"{im_start}{system_text}{im_end}" + raw_text
context_tokens += (
nl_tokens
+ im_start_tokens
+ _tokenize_str("user", query)[1]
+ im_end_tokens
+ nl_tokens
+ im_start_tokens
+ tokenizer.encode("assistant")
+ nl_tokens
)
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
elif chat_format == "raw":
raw_text = query
context_tokens = tokenizer.encode(raw_text)
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
return raw_text, context_tokens
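# Illustrative example (not from the original source): with chat_format="chatml" and an empty
# history, raw_text is laid out as
#   "<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
# and context_tokens is the matching list of token ids built from the same pieces.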
def _decode_default(
tokens: List[int],
*,
stop_words: List[str],
eod_words: List[str],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace',
):
trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
if verbose:
print("\nRaw Generate: ", trim_decode_tokens)
end_reason = f"Gen length {len(tokens)}"
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
for eod_word in eod_words:
if eod_word in trim_decode_tokens:
end_reason = f"Gen {eod_word!r}"
trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print("\nEnd Reason:", end_reason)
print("\nGenerate: ", trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def _decode_chatml(
tokens: List[int],
*,
stop_words: List[str],
eod_token_ids: List[int],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace'
):
end_reason = f"Gen length {len(tokens)}"
eod_token_idx = context_length
for eod_token_idx in range(context_length, len(tokens)):
if tokens[eod_token_idx] in eod_token_ids:
end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
break
trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
if verbose:
print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
print("\nRaw Generate:", trim_decode_tokens)
print("\nEnd Reason:", end_reason)
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print("\nGenerate:", trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def decode_tokens(
tokens: Union[torch.LongTensor, TokensType],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
chat_format: str,
verbose: bool = False,
return_end_reason: bool = False,
errors: str="replace",
) -> str:
if torch.is_tensor(tokens):
tokens = tokens.cpu().numpy().tolist()
if chat_format == "chatml":
return _decode_chatml(
tokens,
stop_words=[],
eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
context_length=context_length,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
elif chat_format == "raw":
return _decode_default(
tokens,
stop_words=["<|endoftext|>"],
eod_words=["<|endoftext|>"],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
class StopWordsLogitsProcessor(LogitsProcessor):
"""
:class:`transformers.LogitsProcessor` that enforces that generation stops when the specified sequences appear.
Args:
stop_words_ids (:obj:`List[List[int]]`):
List of list of token ids of stop ids. In order to get the tokens of the words
that should not appear in the generated text, use :obj:`tokenizer(bad_word,
add_prefix_space=True).input_ids`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
raise ValueError(
f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
)
if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
raise ValueError(
f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
)
if any(
any(
(not isinstance(token_id, (int, np.integer)) or token_id < 0)
for token_id in stop_word_ids
)
for stop_word_ids in stop_words_ids
):
raise ValueError(
f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
)
self.stop_words_ids = list(
filter(
lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
)
)
self.eos_token_id = eos_token_id
for stop_token_seq in self.stop_words_ids:
assert (
len(stop_token_seq) > 0
), "Stop words token sequences {} cannot have an empty list".format(
stop_words_ids
)
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
stopped_samples = self._calc_stopped_samples(input_ids)
for i, should_stop in enumerate(stopped_samples):
if should_stop:
scores[i, self.eos_token_id] = float(2**15)
return scores
def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
elif len(tokens) > len(prev_tokens):
# if bad word tokens are longer than prev input_ids they can't be equal
return False
elif prev_tokens[-len(tokens) :].tolist() == tokens:
# if tokens match
return True
else:
return False
def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
stopped_samples = []
for prev_input_ids_slice in prev_input_ids:
match = False
for stop_token_seq in self.stop_words_ids:
if self._tokens_match(prev_input_ids_slice, stop_token_seq):
# the stop sequence matches the tail of the generated tokens
match = True
break
stopped_samples.append(match)
return stopped_samples
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
"""This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313"""
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
# Convert to 1D
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
from vary.train.train import train
if __name__ == "__main__":
train()
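# Example launch (hypothetical filename, paths and flags; the accepted arguments are whatever
# vary.train.train defines, typically the standard HuggingFace TrainingArguments):
#   deepspeed --num_gpus 8 train_flash_attn.py --model_name_or_path /path/to/qwen \
#       --output_dir ./checkpoints/vary --bf16 True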