"configs/datasets/infinitebench/vscode:/vscode.git/clone" did not exist on "aa2dd2b58c65dcb171c2ef328d1f60d1dc7ea27a"
Commit 0bc22e1d authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
import io
import os
import copy
import json
import logging
import torch
import transformers
from typing import List, Optional, Tuple, Union, Dict, Sequence
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from vary.utils.constants import *
"""
这段代码定义了一个名为 BaseDataset 的基础数据集类,它继承自 PyTorch 的 Dataset 类。这个类主要用于处理和存储数据,以便于后续的模型训练。
"""
class BaseDataset(Dataset):
def __init__(
self,
datasets: str, # 是数据集的路径
tokenizer: transformers.PreTrainedTokenizer, # 用于分词的tokenizer
multimodal_cfg: dict # 多模态配置字典,包含不同模态的配置信息
):
"""
初始化BaseDataset类的实例。
参数:
- datasets: str,数据集的路径。
- tokenizer: transformers.PreTrainedTokenizer,用于对文本进行分词的预训练tokenizer。
- multimodal_cfg: dict,包含多种模态配置的字典,例如图像和文本的配置。
"""
super(BaseDataset, self).__init__()
self.tokenizer = tokenizer
self.multimodal_cfg = multimodal_cfg
# 记录使用多少个tokens来表示图像
logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing image")
def image_processor(self, image):
"""
image_processor 方法是用来处理图像的。这个方法接收一个图像作为输入,然后根据 multimodal_cfg 中的配置信息来处理这个图像。
首先,获取两个图像处理器:processor 和 processor_high。然后,根据 multimodal_cfg 中的 image_aspect_ratio 配置来决定如何处理图像。
如果 image_aspect_ratio 是 "keep",那么会保持图像的长宽比;如果 image_aspect_ratio 是 "pad",那么会将图像扩展到正方形;
如果 image_aspect_ratio 是其他值,那么会直接处理图像。最后,使用 processor_high 来处理复制的图像,并返回处理后的两个图像。
"""
"""
处理图像,通过两个处理器对图像进行处理:第一个处理器通常为预训练的模型(如CLIP、VIT),第二个处理器通常为设计的图像编码器(如SAM、SWIN、CNN)。
参数:
- image: 图像对象,待处理的图像。
返回值:
- image: 经过第一个处理器处理后的图像张量。
- image_high: 经过第二个处理器处理后的图像张量。
"""
processor = self.multimodal_cfg['image_processor'] # the first processor, usually is the clip pretrained model (vit)
processor_high = self.multimodal_cfg['image_processor_high'] # the second processor, usually is the designed image encoder (sam/swin/cnn)
image_high = image.copy()
# TODO the 'keep', 'padding' only used for the first processor
# 根据配置决定如何处理图像以保持纵横比或填充到固定大小
if self.multimodal_cfg['image_aspect_ratio'] == 'keep':
# 保持纵横比处理图像
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 448, 224
shortest_edge = int(min(max_len / aspect_ratio, min_len))
# 将图像转换为张量
image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0]
elif self.multimodal_cfg['image_aspect_ratio'] == 'pad':
# 通过填充方式处理图像,保持为正方形
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img) # for simpler box processing
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img) # for simpler box processing
return result
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": 224})['pixel_values'][0]
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
# 使用第二个处理器处理图像
image_high = processor_high(image_high)
return image, image_high
def __len__(self):
return len(self.list_data_dict)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
pass
\ No newline at end of file
import io
import os
import copy
import json
import logging
import torch
import random
from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from vary.data.base_dataset import BaseDataset
from vary.utils.constants import *
from vary.utils import conversation as conversation_lib
# 这段代码定义了一个名为 CaptionDataset 的类,它继承自 BaseDataset 类。
# 这个类主要用于处理和存储对话格式的数据集,以便于后续的模型训练。
class CaptionDataset(BaseDataset):
"""Conversation format dataset stage2 fine-tuning."""
"""CaptionDataset 类的 __init__ 方法接收三个参数:datasets、tokenizer 和 multimodal_cfg。
其中,datasets 是数据集的名称,tokenizer 是一个预训练的分词器,multimodal_cfg 是一个字典,包含了多模态配置信息。"""
def __init__(self, datasets, tokenizer, multimodal_cfg):
# 调用父类初始化方法
super(CaptionDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
# v0 version format conversation
"""在 __init__ 方法中,首先调用父类的初始化方法,然后设置默认的对话模板为 "default"。
然后,对每个数据集,读取其注释数据和图像路径,并将这些数据添加到 list_data_dict 和 list_image_path 列表中。
然后,将这两个列表打乱并重新组合,最后将这两个列表以及一些分词器生成的 token 保存为类的属性。"""
# 设置默认对话模板
conversation_lib.default_conversation = conversation_lib.conv_templates["default"]
# 开始格式化输入数据
logging.warning("Formatting inputs into conversation type: v0-fixed")
logging.warning("Loading data...")
# 初始化用于存储数据和图像路径的列表
list_data_dict = []
list_image_path = []
# 遍历数据集名称,加载数据
for name in datasets.split("+"):
dataset = CONVERSATION_DATA[name] # in vary.utils # 从预定义的数据字典中获取数据集
# 加载注释数据
data_path = dataset['annotations']
data = json.load(open(data_path, "r"))
list_data_dict.extend(data)
# 获取图像路径,并与注释数据对应
image_path = dataset['images']
list_image_path.extend([image_path] * len(data))
# 输出数据加载信息
logging.warning(f"Data from {data_path} provide {len(data)} conversations.")
# 确保数据和图像路径的数量一致
assert len(list_data_dict) == len(list_image_path)
# 输出总数据量
logging.warning(f"{len(list_data_dict)} conversations in total.")
# 打乱并重新组合数据和图像路径
a_new_list = list(zip(list_data_dict, list_image_path))
random.shuffle(a_new_list)
list_data_dict_new, list_image_path_new = zip(*a_new_list)
self.list_data_dict = list_data_dict_new
self.list_image_path = list_image_path_new
# 为图像标记生成对应的token ID
self.im_patch_token, self.im_start_token, self.im_end_token = tokenizer.convert_tokens_to_ids(
[DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
def multimodal_processor(self, sources):
"""
处理多模态源数据,将图像标记替换为一定长度的图像标记序列。
:param sources: 包含多个源数据的列表,每个源数据是一个字典列表,其中第一个字典代表图像信息,
其他字典代表与图像相关联的文本信息。
:return: 经过图像标记替换处理后的源数据列表。
"""
for source in sources:
# 为每个源数据的图像信息设置默认的图像标记
source[0]['value'] = DEFAULT_IMAGE_TOKEN
for sentence in source:
# 根据配置,生成替换用的图像标记序列
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
if self.multimodal_cfg['use_im_start_end']:
# 如果使用图像开始和结束标记,则在图像标记序列两端添加开始和结束标记
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
# 使用生成的图像标记序列替换句子中的默认图像标记
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
return sources
def _tokenize_fn(self, strings):
"""Tokenize a list of strings."""
"""
对字符串列表进行分词处理,以供模型输入。
此函数处理一个字符串列表,对每个字符串进行分词,填充或截断令牌到最大长度,并创建输入掩码。
它还处理特定的分词情况,如确保每个输入的最后一个令牌是特定令牌。
参数:
- strings (list): 需要分词的字符串列表。
返回:
- dict: 包含以下键值对的字典:
- input_ids (list): 表示每个输入字符串的令牌ID列表。
- labels (list): 用作标签的令牌ID列表,通常与input_ids相同。
- input_ids_lens (list): 分词后的字符串长度列表。
- labels_lens (list): 标签长度列表。
"""
# 对字符串列表中的每个字符串进行分词并准备模型输入
tokenized_list = [
self.tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=self.tokenizer.model_max_length,
truncation=True,
) for text in strings
]
# 从分词结果中提取输入ID和标签
input_ids = labels = [
tokenized.input_ids[0] for tokenized in tokenized_list
]
# 确保每个输入的最后一个令牌是特定令牌(此处为2)
for idx, ii in enumerate(input_ids):
if ii[-1] != 2:
input_ids[idx][-1] = 2
labels[idx][-1] = 2
# 计算每个输入的非填充令牌长度
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
# 返回包含所有必需处理数据的字典
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(self, target, tokenized_lens, speakers):
"""
隐藏目标序列中特定部分的数据。
参数:
- target: 待处理的目标序列,通常是一个序列化的文本或音频表示。
- tokenized_lens: 分词后每个序列的长度列表。
- speakers: 对应于每个序列的说话者信息列表。
返回值:
- 修改后的target,其中根据说话者信息和长度信息进行了掩码处理。
"""
# cur_idx = 0
# 初始化当前索引为第一个序列的长度
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:] # 移除已经处理的第一个序列的长度
target[:cur_idx] = IGNORE_INDEX # 隐藏目标序列的起始部分
# 遍历剩余的序列长度和说话者信息,进行相应的掩码处理
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker.lower() == "human":
target[cur_idx:tokenized_len] = IGNORE_INDEX # 如果说话者是“human”,隐藏该部分目标序列
cur_idx += tokenized_len # 更新当前索引位置,准备处理下一个序列
def _add_speaker_and_signal(self, header, source, get_conversation=True):
"""Add speaker and start/end signal on each round."""
"""
在每个回合中添加说话者和开始/结束信号。
此方法处理一个句子列表,识别每个句子的说话者,并向句子添加开始和结束信号。
如果`get_conversation`为真,它将所有句子连接成一个完整的对话字符串。
参数:
- header (str): 对话的初始消息或标题。
- source (list): 每个句子是一个字典,至少包含'from'键和'value'键。
'from'键指定说话者,'value'键指定所说的话。
- get_conversation (bool, 可选): 标志,表示是否将所有句子合并为一个单一的对话字符串。
默认为True。
返回:
- str: 添加了说话者并带有适当开始和结束信号的最终对话字符串。
如果`get_conversation`为False,则返回None。
"""
BEGIN_SIGNAL = "</s>"
END_SIGNAL = "\n"
conversation = header # 初始化对话,附带标题
# 处理源列表中的每个句子
for sentence in source:
from_str = sentence["from"] # 获取说话者 获取说话信息
# 将说话者名称映射到预定义的角色
if from_str.lower() == "human":
from_str = conversation_lib.default_conversation.roles[0]
else:
from_str = conversation_lib.default_conversation.roles[1]
sentence["value"] = sentence["value"] + END_SIGNAL # 向句子添加结束信号
if get_conversation:
conversation += sentence["value"] # 若需要,将句子拼接到对话中
conversation += BEGIN_SIGNAL
return conversation
def token_processor(self, sources):
"""
Given a list of sources, each is a conversation list. This transform:
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
multimodal_processor 方法是用来处理多模态数据的。这个方法接收一个名为 sources 的参数,这个参数应该是一个包含多个数据源的列表。
每个数据源都应该是一个字典,包含 "value" 和 "sentence" 这两个键。
这个方法会遍历每个数据源,将 "value" 键的值替换为特定的 token,然后返回处理后的数据源列表。
"""
"""
处理输入的对话列表,进行标记化和掩码处理。
参数:
sources: 一个包含多个会话列表的列表,每个会话列表是一个字典列表,其中每个字典包含"value"和"sentence"键。
返回:
一个字典,包含输入ID和标签(被掩码处理的人类语言)的列表。
"""
# add end signal and concatenate together
# 为每个会话添加信号并在末尾连接
conversations = []
for source in sources:
header = ''
conversation = self._add_speaker_and_signal(header, source) # 添加说话者和信号到会话
conversations.append(conversation)
# 对会话进行标记化
conversations_tokenized = self._tokenize_fn(conversations) # 标记化会话文本
input_ids = conversations_tokenized["input_ids"] # 获取输入ID
# 创建目标的深拷贝,并掩码处理人类语言单词
targets = copy.deepcopy(input_ids) # 创建输入ID的深拷贝作为目标
for target, source in zip(targets, sources): # 遍历目标和源,进行掩码处理
tokenized_lens = self._tokenize_fn([header] + [s["value"] for s in source])["input_ids_lens"] # 计算每个说话者的标记长度
speakers = [sentence["from"] for sentence in source] # 获取每个句子的说话者
self._mask_targets(target, tokenized_lens, speakers) # 根据说话者信息掩码目标序列
return dict(input_ids=input_ids, labels=targets) # 返回处理后的输入和目标
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# data = self.list_data_dict[i]
"""
根据索引i获取数据集中的一个项目。
参数:
- i: int, 数据集中项的索引。
返回值:
- Dict[str, torch.Tensor]: 包含项目数据的字典,其中可能包括'image', 'image_high', 'input_ids', 'labels'键。
"""
# 深拷贝数据以避免修改原始数据
data = copy.deepcopy(self.list_data_dict[i])
# 处理包含图像数据的项
if isinstance(data, dict):
if 'image' in data:
image_path = self.list_image_path[i]
image_file = data['image']
# TODO this is a bug, because some json has wrong path
# 尝试打开并处理图像
try:
image = Image.open(image_path + image_file).convert('RGB')
except:
print(f'cannot identify image file {image_path+image_file}.')
return self.__getitem__(0)
# 尝试对图像进行进一步处理
try:
image, image_high = self.image_processor(image)
except:
print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!')
return self.__getitem__(0)
conversations = self.multimodal_processor([data["conversations"]])
else:
# 对于不包含图像数据的项
conversations = [data]
# align with fastchat & llava here, put the conversation into a list for tokenization
# 对会话数据进行标记化处理
data_dict = self.token_processor(conversations)
# 简化数据字典结构
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
# 如果数据包含图像,添加图像数据到字典中
if isinstance(data, dict) and 'image' in data:
data_dict['image'] = [image]
data_dict['image_high'] = [image_high]
else:
# 为不包含图像的数据项填充默认的图像张量
crop_size = self.multimodal_cfg['image_processor'].crop_size
data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
# TODO sam is 1024*1024
data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
return data_dict
import io
import os
import copy
import json
import logging
import torch
import random
from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from vary.data.base_dataset import BaseDataset
from vary.utils.constants import *
from vary.utils import conversation as conversation_lib
class ConversationDataset(BaseDataset):
"""Conversation format dataset stage2 fine-tuning."""
"""
这段代码是一个名为ConversationDataset的Python类,它继承自BaseDataset类。
这个类主要用于处理对话格式的数据集,特别是在多模态配置(multimodal_cfg)的情况下。
这个方法中,首先调用了父类的构造函数来初始化一些基本的属性。然后,它设置了默认的对话模板,并打印了一些日志信息。
接着,它遍历datasets,加载数据和图片路径,并将它们添加到列表中。最后,它将数据和图片路径的列表打乱,并将它们赋值给类的属性。
"""
def __init__(self, datasets, tokenizer, multimodal_cfg):
# 调用父类的构造函数
super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
# v0 version format conversation
# 初始化会话格式为mpt
conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
# 输出警告信息,说明正在格式化输入为mpt-fixed类型的会话
logging.warning("Formatting inputs into conversation type: mpt-fixed")
logging.warning("Loading data...")
# 初始化数据字典列表和图片路径列表
list_data_dict = []
list_image_path = []
# 遍历数据集名称,加载每个数据集
for name in datasets.split("+"):
# for name in vary_data_dict[name_all]:
dataset = CONVERSATION_DATA[name]
data_path = dataset['annotations']
data = json.load(open(data_path, "r")) # 加载注释数据
list_data_dict.extend(data) # 将数据添加到列表中
image_path = dataset['images']
list_image_path.extend([image_path] * len(data)) # 为每个数据项添加图片路径
# 输出每个数据集提供的会话数量
logging.warning(f"Data from {data_path} provide {len(data)} conversations.")
# 确保数据项和图片路径数量一致
assert len(list_data_dict) == len(list_image_path)
# 输出总共有多少会话
logging.warning(f"{len(list_data_dict)} conversations in total.")
# 打乱数据和图片路径的组合,并重新分配
a_new_list = list(zip(list_data_dict, list_image_path))
random.shuffle(a_new_list)
list_data_dict_new, list_image_path_new = zip(*a_new_list)
self.list_data_dict = list_data_dict_new
self.list_image_path = list_image_path_new
# 设置图像标记相关的令牌
self.im_patch_token = 151859
self.im_start_token = 151857
self.im_end_token = 151858
def multimodal_processor(self, sources):
"""
这段代码是一个名为`ConversationDataset`的Python类,它继承自`BaseDataset`类。
这个类主要用于处理对话格式的数据集,特别是在多模态配置(`multimodal_cfg`)的情况下。
multimodal_processor方法:这个方法用于处理多模态的源数据。它遍历sources,并检查每个源数据的第一个元素的'value'是否包含默认的图片标记。
如果不包含,它会抛出一个断言错误。
"""
"""
处理多模态源数据的方法。
此方法遍历输入的源数据列表,对每条源数据进行处理。首先,检查是否在多模态配置中指定了在对话前添加图片标记(`sep_image_conv_front`),
如果指定且第一个元素的'value'中不包含默认的图片标记,则抛出断言错误。然后,对于每个源数据中的句子,
将默认图片标记替换为一定长度的图片标记字符串,如果在多模态配置中启用了使用图片开始和结束标记(`use_im_start_end`),
则在图片标记字符串前后添加图片开始和结束标记。
参数:
- sources: 包含多模态源数据的列表。每个源数据是一个列表,其中第一个元素是特殊的源标记,后续元素是具体的句子信息。
返回值:
- 处理后的多模态源数据列表。
"""
for source in sources:
# 检查是否在对话前添加图片标记,并处理图片标记
if self.multimodal_cfg['sep_image_conv_front']:
assert DEFAULT_IMAGE_TOKEN in source[0]['value']
source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
# 重新格式化源标记,添加角色信息
for sentence in source:
# 准备替换的图片标记字符串
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
# if self.multimodal_cfg['use_im_start_end']:
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
# 如果启用了使用图片开始和结束标记,格式化图片标记字符串
sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token)
return sources
def _tokenize_fn(self, strings):
"""Tokenize a list of strings."""
"""
对字符串列表进行分词处理。
此函数接受一个字符串列表,并使用实例关联的分词器对其进行分词。
返回一个字典,其中包含分词后的输入、标签以及它们的长度。
参数:
- strings (list): 需要分词的字符串列表。
返回:
- dict: 包含以下键值对的字典:
- input_ids (list): 分词后输入ID的列表。
- labels (list): 分词后标签ID的列表(如果适用,否则与input_ids相同)。
- input_ids_lens (list): 分词后输入的长度列表。
- labels_lens (list): 分词后标签的长度列表。
"""
# 对列表中的每个字符串进行分词并存储结果
tokenized_list = [
self.tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=self.tokenizer.model_max_length,
truncation=True,
) for text in strings
]
# 从每个分词后的字符串中提取input_ids
input_ids = labels = [
tokenized.input_ids[0] for tokenized in tokenized_list
]
# 计算每个分词字符串的长度(不包括填充令牌)
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
# 返回包含所有相关分词数据的字典
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(self, target, tokenized_lens, speakers):
"""
隐藏目标序列中特定部分的数据。
参数:
- target: 待处理的目标序列,通常是一个序列化的文本或音频表示。
- tokenized_lens: 分词后每个序列的长度列表。
- speakers: 对应于每个序列的说话者信息列表。
返回值:
- 修改后的target,其中特定部分被隐藏。
"""
# cur_idx = 0
# 初始化当前索引为第一个序列的长度
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:] # 移除已经处理的第一个序列的长度
target[:cur_idx] = IGNORE_INDEX # 隐藏目标序列的起始部分
# 遍历剩余的序列长度和说话者信息,隐藏相应部分的目标序列
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker.lower() == "human":
target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX # 隐藏human说话者的部分,偏移量+2可能用于标识说话者
cur_idx += tokenized_len # 更新当前索引以处理下一个序列
def token_processor(self, sources):
"""
处理对话源数据,将其转换为适用于模型输入的格式,包括标记化和目标掩码处理。
参数:
sources - 一个包含多个对话来源的列表,每个对话来源自身是一个列表,其中包含了多个句子字典,
每个字典包含"from"和"value"键,分别指明句子的来源("human"或"gpt")和句子内容。
返回值:
一个字典,包含两个键:
- input_ids: 经过标记化处理后的输入序列ID,格式为PyTorch张量。
- labels: 对应的掩码处理后的目标序列ID,用于模型训练,格式同input_ids。
"""
# 复制默认对话对象并定义角色
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates # 应用对话模板
conversations = []
for i, source in enumerate(sources):
# 如果首个消息不是来自人类,则跳过
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
# 初始化对话消息
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
# 确保角色匹配预期顺序
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
# 对话内容标记化
input_ids = self.tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=self.tokenizer.model_max_length,
truncation=True,
).input_ids
# input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
targets = input_ids.clone() # 确保使用的分隔符风格匹配预期
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
# 目标掩码处理
sep = conv.sep + conv.roles[1]
for conversation, target in zip(conversations, targets):
# 计算非填充token的总数
total_len = int(target.ne(self.tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep)
re_rounds = [conv.sep.join(rounds[:3])] # 系统 + 用户 + gpt
for conv_idx in range(3, len(rounds), 2):
re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # 用户 + gpt
cur_len = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids)
instruction_len = len(self.tokenizer(parts[0]).input_ids)
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
# 警告处理:标记化结果与预期长度不匹配
if cur_len < self.tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# data = self.list_data_dict[i]
"""
根据索引i获取数据项,并进行预处理。如果数据项包含图像,则同时进行图像预处理。
参数:
i: 索引值,用于指定要获取的数据项。
返回值:
一个字典,包含经过处理的文本和图像数据,键包括input_ids, labels, image, image_high等。
"""
# 深拷贝数据项以避免修改原始数据
data = copy.deepcopy(self.list_data_dict[i])
# 检查数据项是否为字典,并处理其中的图像数据
if isinstance(data, dict):
if 'image' in data:
image_path = self.list_image_path[i]
image_file = data['image']
# 尝试打开并转换图像
try:
image = Image.open(image_path + image_file).convert('RGB')
except:
print(f'cannot identify image file {image_path + image_file}.')
return self.__getitem__(0)
# 尝试打开并转换图像
try:
image, image_1 = self.image_processor(image)
except:
print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!')
return self.__getitem__(0)
# 对对话数据进行预处理
conversations = self.multimodal_processor([data["conversations"]])
else:
# 如果数据项不是字典,则直接将其作为对话处理
conversations = [data]
# 使用token处理器处理对话数据
# align with fastchat & llava here, put the conversation into a list for tokenization
data_dict = self.token_processor(conversations)
# 精简处理后的数据字典,只保留input_ids和labels
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
# 如果数据项包含图像,将图像数据添加到数据字典中
if isinstance(data, dict) and 'image' in data:
data_dict['image'] = [image]
data_dict['image_high'] = [image_1]
else:
# 如果数据项不包含图像,为图像数据添加占位符
crop_size = self.multimodal_cfg['image_processor'].crop_size
data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
return data_dict
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.vision_encoder.sam import build_sam_vit_b
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
# 这段代码是用于加载和评估一个预训练模型的。它使用了Hugging Face的Transformers库,以及PIL库来处理图像。
def load_image(image_file):
"""
加载图像。
参数:
- image_file: 图像的文件路径或URL。
返回值:
- image: 打开的图像,转换为RGB模式。
"""
# 判断图像文件是否是URL
if image_file.startswith('http') or image_file.startswith('https'):
# 从URL下载图像
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
# 从本地文件系统打开图像
image = Image.open(image_file).convert('RGB')
return image
def eval_model(args):
"""
评估模型性能。
参数:
- args: 一个包含模型名称和图像文件路径等信息的参数对象。
说明:
此函数加载指定的模型和对应的tokenizer,将输入的图像转换为文本描述。
使用关键词停止准则以防止生成不必要的文本。
"""
# Model
# 初始化模型和tokenizer
disable_torch_init() # 禁用PyTorch的默认初始化
model_name = os.path.expanduser(args.model_name) # 解析模型名称
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) # 解析模型名称
model = varyOPTForCausalLM.from_pretrained(model_name) # 从预训练加载模型
# 将模型转移到GPU并设置数据类型
model.to(device='cuda', dtype=torch.bfloat16)
# 设置图像处理和token长度
image_processor_high = test_transform # 定义图像预处理
image_token_len = 256 # 图像token长度
# 构建输入prompt
prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
inputs = tokenizer([prompt]) # 对prompt进行tokenization
# 加载输入图像
image = load_image(args.image_file)
image_1 = image.copy()
# 加载输入图像
image_tensor_1 = image_processor_high(image_1).to(torch.bfloat16)
# 准备模型输入
input_ids = torch.as_tensor(inputs.input_ids).cuda()
# 设置停止准则
stop_str = '</s>'
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
# 设置文本流
"""
这段代码首先创建了一个TextStreamer实例,用于处理模型生成的文本流。然后,它在CUDA设备上以mixed precision模式运行模型的generate方法,生成基于输入图像的新tokens。
这个过程包括输入tokens、图像张量、随机采样设置、束搜索参数、文本流处理器、最大新生成tokens数以及停止准则。
"""
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# 使用mixed precision进行模型推理
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = model.generate(
input_ids,
images=[(image_tensor_1.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).cuda())],
do_sample=True,
num_beams = 1,
streamer=streamer,
max_new_tokens=4096,
stopping_criteria=[stopping_criteria]
)
# input_token_len = input_ids.shape[1]
# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
# if outputs.endswith(stop_str):
# outputs = outputs[:-len(stop_str)]
# outputs = outputs.strip()
# print(outputs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--image-file", type=str, required=True)
# parser.add_argument("--query", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
args = parser.parse_args()
eval_model(args)
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.vision_encoder.sam import build_sam_vit_b
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def eval_model(args):
# Model
disable_torch_init()
model_name = os.path.expanduser(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = varyOPTForCausalLM.from_pretrained(model_name)
model.to(device='cuda', dtype=torch.bfloat16)
# image_processor_high = test_transform
image_processor_high = BlipImageEvalProcessor(image_size=1024)
image_token_len = 256
qs = "Provide the OCR results of this image."
# qs = "detect Person in this image.Your answer should be structured precisely according to the category:[xmin,ymin,xmax,ymax] format."
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
conv_mode = "v1"
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
print(prompt)
image = load_image(args.image_file)
image_1 = image.copy()
image_tensor_1 = image_processor_high(image_1).to(torch.bfloat16)
input_ids = torch.as_tensor(inputs.input_ids).cuda()
stop_str = '</s>'
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = model.generate(
input_ids,
images=[(image_tensor_1.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).cuda())],
do_sample=True,
num_beams = 1,
streamer=streamer,
max_new_tokens=2048,
stopping_criteria=[stopping_criteria]
)
# input_token_len = input_ids.shape[1]
# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
# if outputs.endswith(stop_str):
# outputs = outputs[:-len(stop_str)]
# outputs = outputs.strip()
# print(outputs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--image-file", type=str, required=True)
# parser.add_argument("--query", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
args = parser.parse_args()
eval_model(args)
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from vary.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def eval_model(args):
# Model
disable_torch_init()
model_name = os.path.expanduser(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = varyQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', trust_remote_code=True)
model.to(device='cuda', dtype=torch.bfloat16)
# TODO download clip-vit in huggingface
image_processor = CLIPImageProcessor.from_pretrained("/home/wanglch/projects/Vary/cache/vit-large-patch14", torch_dtype=torch.float16)
image_processor_high = test_transform
use_im_start_end = True
image_token_len = 256
qs = 'Provide the ocr results of this image.'
# qs = 'Detect the red hat in this image.'
# qs = 'Describe this image in within 100 words.'
if use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
conv_mode = "mpt"
args.conv_mode = conv_mode
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
image = load_image(args.image_file)
image_1 = image.copy()
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
image_tensor_1 = image_processor_high(image_1)
input_ids = torch.as_tensor(inputs.input_ids).cuda()
# stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = model.generate(
input_ids,
images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
do_sample=True,
num_beams = 1,
# temperature=0.2,
streamer=streamer,
max_new_tokens=2048,
stopping_criteria=[stopping_criteria]
)
# print(output_ids)
# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
# # conv.messages[-1][-1] = outputs
# if outputs.endswith(stop_str):
# outputs = outputs[:-len(stop_str)]
# outputs = outputs.strip()
# print(outputs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--image-file", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
args = parser.parse_args()
eval_model(args)
from .vary_opt import varyOPTModel, varyOPTForCausalLM
from .vary_qwen_vary import varyQwenModel, varyQwenForCausalLM, varyConfig
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OPT model."""
import random
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from transformers.models.opt.configuration_opt import OPTConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
_CONFIG_FOR_DOC = "OPTConfig"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
# Base model docstring
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
# SequenceClassification docstring
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
_SEQ_CLASS_EXPECTED_LOSS = 1.71
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
# QuestionAnswering docstring
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
_QA_EXPECTED_LOSS = 7.41
_QA_TARGET_START_INDEX = 14
_QA_TARGET_END_INDEX = 15
OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-2.7b",
"facebook/opt-6.7b",
"facebook/opt-13b",
"facebook/opt-30b",
# See all OPT models at https://huggingface.co/models?filter=opt
]
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
class OPTLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(
self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
# create positions depending on attention_mask
positions = (
torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
).long() - 1
# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
return super().forward(positions + self.offset)
class OPTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attention_mask
)
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
if attn_weights.dtype == torch.float16:
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(torch.float16)
else:
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights_reshaped.view(
bsz * self.num_heads, tgt_len, src_len
)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = OPTAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = (residual + hidden_states).view(hidden_states_shape)
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
OPT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`OPTConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare OPT Model outputting raw hidden-states without any specific head on top.",
OPT_START_DOCSTRING,
)
class OPTPreTrainedModel(PreTrainedModel):
config_class = OPTConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["OPTDecoderLayer"]
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (OPTDecoder)):
module.gradient_checkpointing = value
OPT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class OPTDecoder(OPTPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
Args:
config: OPTConfig
"""
def __init__(self, config: OPTConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.word_embed_proj_dim, self.padding_idx
)
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size
)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = nn.Linear(
config.hidden_size, config.word_embed_proj_dim, bias=False
)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = nn.Linear(
config.word_embed_proj_dim, config.hidden_size, bias=False
)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList(
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
past_key_values_length=past_key_values_length,
).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
query_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if query_embeds is not None:
inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
input_shape = inputs_embeds.size()[:-1]
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device
)
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@add_start_docstrings(
"The bare OPT Model outputting raw hidden-states without any specific head on top.",
OPT_START_DOCSTRING,
)
class OPTModel(OPTPreTrainedModel):
def __init__(self, config: OPTConfig):
super().__init__(config)
self.decoder = OPTDecoder(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.decoder.embed_tokens
def set_input_embeddings(self, value):
self.decoder.embed_tokens = value
def get_decoder(self):
return self.decoder
@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
query_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
query_embeds=query_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs
return BaseModelOutputWithPast(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
)
class OPTForCausalLM(OPTPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = OPTModel(config)
# the lm_head weight is automatically tied to the embed tokens weight
self.lm_head = nn.Linear(
config.word_embed_proj_dim, config.vocab_size, bias=False
)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
query_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
reduction: Optional[str] = "mean",
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Returns:
Example:
```python
>>> from transformers import GPT2Tokenizer, OPTForCausalLM
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
query_embeds=query_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0]).contiguous()
loss = None
if labels is not None:
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction=reduction)
loss = loss_fct(
shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
)
if reduction == "none":
loss = loss.view(shift_logits.size(0), -1).sum(1)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids=None,
query_embeds=None,
past=None,
attention_mask=None,
use_cache=None,
**kwargs,
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
if input_ids is not None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
query_embeds = None
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids,
"query_embeds": query_embeds,
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx) for past_state in layer_past
),
)
return reordered_past
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from transformers import PretrainedConfig
class QWenConfig(PretrainedConfig):
model_type = "qwen"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
emb_dropout_prob=0.0,
attn_dropout_prob=0.0,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
max_position_embeddings=8192,
scale_attn_weights=True,
use_cache=True,
bf16=False,
fp16=False,
fp32=False,
kv_channels=128,
rotary_pct=1.0,
rotary_emb_base=10000,
use_dynamic_ntk=True,
use_logn_attn=True,
use_flash_attn="auto",
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.emb_dropout_prob = emb_dropout_prob
self.attn_dropout_prob = attn_dropout_prob
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.bf16 = bf16
self.fp16 = fp16
self.fp32 = fp32
self.kv_channels = kv_channels
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.no_bias = no_bias
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs
)
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
try:
from einops import rearrange
except ImportError:
rearrange = None
from torch import nn
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
from .configuration_qwen import QWenConfig
from .qwen_generation_utils import (
HistoryType,
make_context,
decode_tokens,
get_stop_words_ids,
StopWordsLogitsProcessor,
)
# from .visual import VisionTransformer
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "qwen"
_CONFIG_FOR_DOC = "QWenConfig"
QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
_ERROR_BAD_CHAT_FORMAT = """\
We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
"""
_SENTINEL = object()
_ERROR_STREAM_IN_CHAT = """\
Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
"""
apply_rotary_emb_func = None
rms_norm = None
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class QWenAttention(nn.Module):
def __init__(self, config):
super().__init__()
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(
torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions),
persistent=False,
)
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
self.seq_length = config.seq_length
self.hidden_size = config.hidden_size
self.split_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.scale_attn_weights = True
self.projection_size = config.kv_channels * config.num_attention_heads
assert self.projection_size % config.num_attention_heads == 0
self.hidden_size_per_attention_head = (
self.projection_size // config.num_attention_heads
)
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
self.c_proj = nn.Linear(
config.hidden_size, self.projection_size, bias=not config.no_bias
)
self.is_fp32 = not (config.bf16 or config.fp16)
self.bf16 = config.bf16
if config.rotary_pct == 1.0:
self.rotary_ndims = None
else:
assert config.rotary_pct < 1
self.rotary_ndims = int(
self.hidden_size_per_attention_head * config.rotary_pct
)
dim = (
self.rotary_ndims
if self.rotary_ndims is not None
else self.hidden_size_per_attention_head
)
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
self.use_dynamic_ntk = config.use_dynamic_ntk
self.use_logn_attn = config.use_logn_attn
logn_list = [
math.log(i, self.seq_length) if i > self.seq_length else 1
for i in range(1, 32768)
]
self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]
self._ntk_cached = 1.0
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
attn_weights = attn_weights / torch.full(
[],
value.size(-1) ** 0.5,
dtype=attn_weights.dtype,
device=attn_weights.device,
)
query_length, key_length = query.size(-2), key.size(-2)
# causal_mask = self.bias[
# :, :, key_length - query_length : key_length, :key_length
# ]
# mask_value = torch.finfo(attn_weights.dtype).min
# mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
# attn_weights.device
# )
# attn_weights = torch.where(
# causal_mask, attn_weights.to(attn_weights.dtype), mask_value
# )
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2)
return attn_output, attn_weights
def _upcast_and_reordered_attn(
self, query, key, value, attention_mask=None, head_mask=None
):
bsz, num_heads, q_seq_len, dk = query.size()
_, _, k_seq_len, _ = key.size()
attn_weights = torch.empty(
bsz * num_heads,
q_seq_len,
k_seq_len,
dtype=torch.float32,
device=query.device,
)
scale_factor = 1.0
if self.scale_attn_weights:
scale_factor /= float(value.size(-1)) ** 0.5
with autocast(enabled=False):
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
-1, dk, k_seq_len
)
attn_weights = torch.baddbmm(
attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[
:, :, key_length - query_length : key_length, :key_length
]
mask_value = torch.finfo(attn_weights.dtype).min
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
attn_weights.device
)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if attn_weights.dtype != torch.float32:
raise RuntimeError(
"Error with upcasting, attn_weights does not have dtype torch.float32"
)
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor
def _merge_heads(self, tensor, num_heads, attn_head_size):
tensor = tensor.contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
):
mixed_x_layer = self.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
kv_seq_len = hidden_states.size()[1]
if layer_past:
# layer past[0] shape: bs * seq_len * head_num * dim
kv_seq_len += layer_past[0].shape[1]
if (
self.use_dynamic_ntk
and kv_seq_len == hidden_states.size()[1]
and not self.training
):
context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
ntk_alpha = 2 ** math.ceil(context_value) - 1
ntk_alpha = max(ntk_alpha, 1)
self._ntk_cached = ntk_alpha
else:
ntk_alpha = self._ntk_cached
rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
hidden_states.device
)
if rotary_pos_emb is not None:
if isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = rotary_pos_emb
else:
rotary_pos_emb = (rotary_pos_emb,) * 2
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
cur_len = query.shape[1]
q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb)
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
if use_cache:
present = (key, value)
else:
present = None
if self.use_logn_attn and not self.training:
if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
seq_start = key.size(1) - query.size(1)
seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
query = query * logn_tensor.expand_as(query)
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
attn_output, attn_weight = self._attn(
query, key, value, attention_mask, head_mask
)
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
)
attn_output = self.c_proj(context_layer)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weight,)
return outputs
class QWenMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.w1 = nn.Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
self.w2 = nn.Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
ff_dim_in = config.intermediate_size // 2
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
def forward(self, hidden_states):
a1 = self.w1(hidden_states)
a2 = self.w2(hidden_states)
intermediate_parallel = a1 * F.silu(a2)
output = self.c_proj(intermediate_parallel)
return output
class QWenBlock(nn.Module):
def __init__(self, config):
super().__init__()
hidden_size = config.hidden_size
self.bf16 = config.bf16
self.ln_1 = RMSNorm(
hidden_size,
eps=config.layer_norm_epsilon,
)
self.attn = QWenAttention(config)
self.ln_2 = RMSNorm(
hidden_size,
eps=config.layer_norm_epsilon,
)
self.mlp = QWenMLP(config)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
):
layernorm_output = self.ln_1(hidden_states)
attn_outputs = self.attn(
layernorm_output,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0]
outputs = attn_outputs[1:]
residual = hidden_states
layernorm_input = attn_output + residual
layernorm_output = self.ln_2(layernorm_input)
residual = layernorm_input
mlp_output = self.mlp(layernorm_output)
hidden_states = residual + mlp_output
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
return outputs
class QWenPreTrainedModel(PreTrainedModel):
config_class = QWenConfig
base_model_prefix = "transformer"
is_parallelizable = False
supports_gradient_checkpointing = True
_no_split_modules = ["QWenBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, RMSNorm):
module.weight.data.fill_(1.0)
for name, p in module.named_parameters():
if name == "c_proj.weight":
p.data.normal_(
mean=0.0,
std=(
self.config.initializer_range
/ math.sqrt(2 * self.config.num_hidden_layers)
),
)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, QWenModel):
module.gradient_checkpointing = value
class QWenModel(QWenPreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config):
super().__init__(config)
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
self.embed_dim = config.hidden_size
self.gradient_checkpointing = False
self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
self.drop = nn.Dropout(config.emb_dropout_prob)
self.h = nn.ModuleList(
[
QWenBlock(
config,
)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = RMSNorm(
self.embed_dim,
eps=config.layer_norm_epsilon,
)
# self.visual = VisionTransformer(**config.visual)
self.post_init()
def get_input_embeddings(self):
return self.wte
def set_input_embeddings(self, new_embeddings):
self.wte = new_embeddings
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
# bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
# eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
# assert (bos_pos[0] == eos_pos[0]).all()
# img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
# images = []
# for i, a, b in img_pos:
# image = input_ids[i][a + 1 : b - 1].tolist()
# image = image[ : image.index(self.config.visual['image_start_id'] + 2)]
# images.append(bytes(image).decode('utf-8'))
# images = self.visual.encode(images)
# assert images.shape[0] == len(images)
# else:
# images = None
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
encoder_attention_mask = None
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_length
)
hidden_states = inputs_embeds
hidden_states = self.drop(hidden_states)
# if images is not None:
# for idx, (i, a, b) in enumerate(img_pos):
# hidden_states[i][a + 1 : b] = images[idx]
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[1],)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class QWenLMHeadModel(QWenPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
def __init__(self, config):
super().__init__(config)
assert (
config.bf16 + config.fp16 + config.fp32 <= 1
), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
if autoset_precision:
if SUPPORT_BF16:
logger.warn(
"The model is automatically converting to bf16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.bf16 = True
elif SUPPORT_FP16:
logger.warn(
"The model is automatically converting to fp16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.fp16 = True
else:
config.fp32 = True
if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
if config.fp32:
if SUPPORT_BF16:
logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
elif SUPPORT_FP16:
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
self.transformer = QWenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.bf16:
self.transformer.bfloat16()
self.lm_head.bfloat16()
if config.fp16:
self.transformer.half()
self.lm_head.half()
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None)
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
)
return model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
return tuple(
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
)
for layer_past in past_key_values
)
def chat(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
append_history: bool = True,
stream: Optional[bool] = _SENTINEL,
stop_words_ids: Optional[List[List[int]]] = None,
**kwargs,
) -> Tuple[str, HistoryType]:
assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
if max_window_size is None:
max_window_size = self.generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
history=history,
system=system,
max_window_size=max_window_size,
chat_format=self.generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
self.generation_config.chat_format, tokenizer
))
input_ids = torch.tensor([context_tokens]).to(self.device)
outputs = self.generate(
input_ids,
stop_words_ids = stop_words_ids,
return_dict_in_generate = False,
**kwargs,
)
response = decode_tokens(
outputs[0],
tokenizer,
raw_text_len=len(raw_text),
context_length=len(context_tokens),
chat_format=self.generation_config.chat_format,
verbose=False,
errors='replace'
)
if append_history:
history.append((query, response))
return response, history
def chat_stream(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
stop_words_ids: Optional[List[List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = None,
**kwargs,
) -> Generator[str, Any, None]:
assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
if max_window_size is None:
max_window_size = self.generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
history=history,
system=system,
max_window_size=max_window_size,
chat_format=self.generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
self.generation_config.chat_format, tokenizer
))
if stop_words_ids is not None:
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=self.generation_config.eos_token_id,
)
if logits_processor is None:
logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
logits_processor.append(stop_words_logits_processor)
input_ids = torch.tensor([context_tokens]).to(self.device)
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
self.__class__.generate_stream = NewGenerationMixin.generate
self.__class__.sample_stream = NewGenerationMixin.sample_stream
stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
def stream_generator():
outputs = []
for token in self.generate_stream(
input_ids,
return_dict_in_generate=False,
generation_config=stream_config,
logits_processor=logits_processor,
seed=-1,
**kwargs):
outputs.append(token.item())
yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
return stream_generator()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
# Process stop_words_ids.
stop_words_ids = kwargs.pop("stop_words_ids", None)
if stop_words_ids is None and generation_config is not None:
stop_words_ids = getattr(generation_config, "stop_words_ids", None)
if stop_words_ids is None:
stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
if stop_words_ids is not None:
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=self.generation_config.eos_token_id,
)
if logits_processor is None:
logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
logits_processor.append(stop_words_logits_processor)
return super().generate(
inputs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
**kwargs,
)
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
self.dim = dim
self.base = base
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
if importlib.util.find_spec("einops") is None:
raise RuntimeError("einops is required for Rotary Embedding")
self._rotary_pos_emb_cache = None
self._seq_len_cached = 0
self._ntk_alpha_cached = 1.0
def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
seqlen = max_seq_len + offset
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
self.inv_freq = 1.0 / (
base
** (
torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
/ self.dim
)
)
self._seq_len_cached = max(2 * seqlen, 16)
self._ntk_alpha_cached = ntk_alpha
seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
from einops import rearrange
self._rotary_pos_emb_cache = rearrange(emb, "n d -> 1 n 1 d")
def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]
def _rotate_half(x):
from einops import rearrange
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs):
if apply_rotary_emb_func is not None and t.is_cuda:
t_ = t.float()
freqs = freqs.squeeze(0).squeeze(1)
cos = freqs[:, : freqs.shape[-1] // 2].cos()
sin = freqs[:, : freqs.shape[-1] // 2].sin()
output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
return output
else:
rot_dim = freqs.shape[-1]
t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
t_ = t_.float()
t_pass_ = t_pass_.float()
t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
return torch.cat((t_, t_pass_), dim=-1).type_as(t)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
if rms_norm is not None and x.is_cuda:
return rms_norm(x, self.weight, self.eps)
else:
output = self._norm(x.float()).type_as(x)
return output * self.weight
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Generation support."""
from typing import Tuple, List, Union, Iterable
import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor
logger = logging.get_logger(__name__)
# Types.
HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]
def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
for tokens in batch:
context_length = len(tokens)
if context_length < seq_length:
tokens.extend([pad_id] * (seq_length - context_length))
return batch
def get_ltor_masks_and_position_ids(
data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss,
):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modifed based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indecies:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, loss_mask, position_ids
def get_batch(context_tokens: torch.LongTensor, eod_id: int):
"""Generate batch from context tokens."""
# Move to GPU.
tokens = context_tokens.contiguous().to(context_tokens.device)
# Get the attention mask and postition ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
eod_id,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False,
)
return tokens, attention_mask, position_ids
def get_stop_words_ids(chat_format, tokenizer):
if chat_format == "raw":
stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
elif chat_format == "chatml":
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
return stop_words_ids
def make_context(
tokenizer: PreTrainedTokenizer,
query: str,
history: List[Tuple[str, str]] = None,
system: str = "",
max_window_size: int = 6144,
chat_format: str = "chatml",
):
if history is None:
history = []
if chat_format == "chatml":
im_start, im_end = "<|im_start|>", "<|im_end|>"
im_start_tokens = [tokenizer.im_start_id]
im_end_tokens = [tokenizer.im_end_id]
nl_tokens = tokenizer.encode("\n")
def _tokenize_str(role, content):
return f"{role}\n{content}", tokenizer.encode(
role, allowed_special=set(tokenizer.IMAGE_ST)
) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))
system_text, system_tokens_part = _tokenize_str("system", system)
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
raw_text = ""
context_tokens = []
for turn_query, turn_response in reversed(history):
query_text, query_tokens_part = _tokenize_str("user", turn_query)
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
if turn_response is not None:
response_text, response_tokens_part = _tokenize_str(
"assistant", turn_response
)
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
prev_chat = (
f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
)
else:
next_context_tokens = nl_tokens + query_tokens + nl_tokens
prev_chat = f"\n{im_start}{query_text}{im_end}\n"
current_context_size = (
len(system_tokens) + len(next_context_tokens) + len(context_tokens)
)
if current_context_size < max_window_size:
context_tokens = next_context_tokens + context_tokens
raw_text = prev_chat + raw_text
else:
break
context_tokens = system_tokens + context_tokens
raw_text = f"{im_start}{system_text}{im_end}" + raw_text
context_tokens += (
nl_tokens
+ im_start_tokens
+ _tokenize_str("user", query)[1]
+ im_end_tokens
+ nl_tokens
+ im_start_tokens
+ tokenizer.encode("assistant")
+ nl_tokens
)
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
elif chat_format == "raw":
raw_text = query
context_tokens = tokenizer.encode(raw_text)
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
return raw_text, context_tokens
def _decode_default(
tokens: List[int],
*,
stop_words: List[str],
eod_words: List[str],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace',
):
trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
if verbose:
print("\nRaw Generate: ", trim_decode_tokens)
end_reason = f"Gen length {len(tokens)}"
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
for eod_word in eod_words:
if eod_word in trim_decode_tokens:
end_reason = f"Gen {eod_word!r}"
trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print("\nEnd Reason:", end_reason)
print("\nGenerate: ", trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def _decode_chatml(
tokens: List[int],
*,
stop_words: List[str],
eod_token_ids: List[int],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace'
):
end_reason = f"Gen length {len(tokens)}"
eod_token_idx = context_length
for eod_token_idx in range(context_length, len(tokens)):
if tokens[eod_token_idx] in eod_token_ids:
end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
break
trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
if verbose:
print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
print("\nRaw Generate:", trim_decode_tokens)
print("\nEnd Reason:", end_reason)
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print("\nGenerate:", trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def decode_tokens(
tokens: Union[torch.LongTensor, TokensType],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
chat_format: str,
verbose: bool = False,
return_end_reason: bool = False,
errors: str="replace",
) -> str:
if torch.is_tensor(tokens):
tokens = tokens.cpu().numpy().tolist()
if chat_format == "chatml":
return _decode_chatml(
tokens,
stop_words=[],
eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
context_length=context_length,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
elif chat_format == "raw":
return _decode_default(
tokens,
stop_words=["<|endoftext|>"],
eod_words=["<|endoftext|>"],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
class StopWordsLogitsProcessor(LogitsProcessor):
"""
:class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration.
Args:
stop_words_ids (:obj:`List[List[int]]`):
List of list of token ids of stop ids. In order to get the tokens of the words
that should not appear in the generated text, use :obj:`tokenizer(bad_word,
add_prefix_space=True).input_ids`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
raise ValueError(
f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
)
if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
raise ValueError(
f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
)
if any(
any(
(not isinstance(token_id, (int, np.integer)) or token_id < 0)
for token_id in stop_word_ids
)
for stop_word_ids in stop_words_ids
):
raise ValueError(
f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
)
self.stop_words_ids = list(
filter(
lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
)
)
self.eos_token_id = eos_token_id
for stop_token_seq in self.stop_words_ids:
assert (
len(stop_token_seq) > 0
), "Stop words token sequences {} cannot have an empty list".format(
stop_words_ids
)
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
stopped_samples = self._calc_stopped_samples(input_ids)
for i, should_stop in enumerate(stopped_samples):
if should_stop:
scores[i, self.eos_token_id] = float(2**15)
return scores
def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
elif len(tokens) > len(prev_tokens):
# if bad word tokens are longer then prev input_ids they can't be equal
return False
elif prev_tokens[-len(tokens) :].tolist() == tokens:
# if tokens match
return True
else:
return False
def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
stopped_samples = []
for prev_input_ids_slice in prev_input_ids:
match = False
for stop_token_seq in self.stop_words_ids:
if self._tokens_match(prev_input_ids_slice, stop_token_seq):
# if tokens do not match continue
match = True
break
stopped_samples.append(match)
return stopped_samples
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
"""This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313"""
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
# Cconvert to 1D
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Tokenization classes for QWen."""
import base64
import logging
import os
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union
import tiktoken
from transformers import PreTrainedTokenizer, AddedToken
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# as the default behavior is changed to allow special tokens in
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
SPECIAL_TOKENS = (
ENDOFTEXT,
IMSTART,
IMEND,
) + EXTRAS
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
with open(tiktoken_bpe_file, "rb") as f:
contents = f.read()
return {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
class QWenTokenizer(PreTrainedTokenizer):
"""QWen tokenizer."""
vocab_files_names = VOCAB_FILES_NAMES
def __init__(
self,
vocab_file,
errors="replace",
**kwargs,
):
super().__init__(**kwargs)
self.errors = errors # how to handle errors in decoding
self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
self.special_tokens = {
token: index
for index, token in enumerate(
SPECIAL_TOKENS, start=len(self.mergeable_ranks)
)
}
enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
mergeable_ranks=self.mergeable_ranks,
special_tokens=self.special_tokens,
)
assert (
len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
self.decoder = {
v: k for k, v in self.mergeable_ranks.items()
} # type: dict[int, bytes|str]
self.decoder.update({v: k for k, v in self.special_tokens.items()})
self.tokenizer = enc # type: tiktoken.Encoding
self.eod_id = self.tokenizer.eot_token
self.im_start_id = self.special_tokens[IMSTART]
self.im_end_id = self.special_tokens[IMEND]
def __len__(self) -> int:
return self.tokenizer.n_vocab
def get_vocab(self) -> Dict[bytes, int]:
return self.mergeable_ranks
def convert_tokens_to_ids(
self, tokens: Union[bytes, str, List[Union[bytes, str]]]
) -> List[int]:
ids = []
if isinstance(tokens, (str, bytes)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.mergeable_ranks.get(tokens)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.mergeable_ranks.get(token))
return ids
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
if not special_tokens and new_tokens:
raise ValueError('Adding regular tokens is not supported')
for token in new_tokens:
surface_form = token.content if isinstance(token, AddedToken) else token
if surface_form not in SPECIAL_TOKENS:
raise ValueError('Adding unknown special tokens is not supported')
return 0
def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
"""
Save only the vocabulary of the tokenizer (vocabulary).
Returns:
`Tuple(str)`: Paths to the files saved.
"""
file_path = os.path.join(save_directory, "qwen.tiktoken")
with open(file_path, "w", encoding="utf8") as w:
for k, v in self.mergeable_ranks.items():
line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
w.write(line)
return (file_path,)
def tokenize(
self,
text: str,
allowed_special: Union[Set, str] = "all",
disallowed_special: Union[Collection, str] = (),
**kwargs,
) -> List[Union[bytes, str]]:
"""
Converts a string in a sequence of tokens.
Args:
text (`str`):
The sequence to be encoded.
allowed_special (`Literal["all"]` or `set`):
The surface forms of the tokens to be encoded as special tokens in regular texts.
Default to "all".
disallowed_special (`Literal["all"]` or `Collection`):
The surface forms of the tokens that should not be in regular texts and trigger errors.
Default to an empty tuple.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific encode method.
Returns:
`List[bytes|str]`: The list of tokens.
"""
tokens = []
text = unicodedata.normalize("NFC", text)
# this implementation takes a detour: text -> token id -> token surface forms
for t in self.tokenizer.encode(
text, allowed_special=allowed_special, disallowed_special=disallowed_special
):
tokens.append(self.decoder[t])
return tokens
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
"""
Converts a sequence of tokens in a single string.
"""
text = ""
temp = b""
for t in tokens:
if isinstance(t, str):
if temp:
text += temp.decode("utf-8", errors=self.errors)
temp = b""
text += t
elif isinstance(t, bytes):
temp += t
else:
raise TypeError("token should only be of type types or str")
if temp:
text += temp.decode("utf-8", errors=self.errors)
return text
@property
def vocab_size(self):
return self.tokenizer.n_vocab
def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
"""Converts an id to a token, special tokens included"""
if index in self.decoder:
return self.decoder[index]
raise ValueError("unknown ids")
def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
"""Converts a token to an id using the vocab, special tokens included"""
if token in self.special_tokens:
return self.special_tokens[token]
if token in self.mergeable_ranks:
return self.mergeable_ranks[token]
raise ValueError("unknown token")
def _tokenize(self, text: str, **kwargs):
"""
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
Do NOT take care of added tokens.
"""
raise NotImplementedError
def _decode(
self,
token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
errors: str = None,
**kwargs,
) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
if skip_special_tokens:
token_ids = [i for i in token_ids if i < self.eod_id]
return self.tokenizer.decode(token_ids, errors=errors or self.errors)
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import cv2
import numpy as np
import torch
# from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
class BaseProcessor:
def __init__(self):
self.transform = lambda x: x
return
def __call__(self, item):
return self.transform(item)
# @classmethod
# def from_config(cls, cfg=None):
# return cls()
# def build(self, **kwargs):
# cfg = OmegaConf.create(kwargs)
# return self.from_config(cfg)
class BlipImageBaseProcessor(BaseProcessor):
def __init__(self, mean=None, std=None):
if mean is None:
mean = (0.48145466, 0.4578275, 0.40821073)
if std is None:
std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = transforms.Normalize(mean, std)
## aug functions
def identity_func(img):
return img
def autocontrast_func(img, cutoff=0):
"""
same output as PIL.ImageOps.autocontrast
"""
n_bins = 256
def tune_channel(ch):
n = ch.size
cut = cutoff * n // 100
if cut == 0:
high, low = ch.max(), ch.min()
else:
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
low = np.argwhere(np.cumsum(hist) > cut)
low = 0 if low.shape[0] == 0 else low[0]
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
if high <= low:
table = np.arange(n_bins)
else:
scale = (n_bins - 1) / (high - low)
offset = -low * scale
table = np.arange(n_bins) * scale + offset
table[table < 0] = 0
table[table > n_bins - 1] = n_bins - 1
table = table.clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def equalize_func(img):
"""
same output as PIL.ImageOps.equalize
PIL's implementation is different from cv2.equalize
"""
n_bins = 256
def tune_channel(ch):
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
non_zero_hist = hist[hist != 0].reshape(-1)
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
if step == 0:
return ch
n = np.empty_like(hist)
n[0] = step // 2
n[1:] = hist[:-1]
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def rotate_func(img, degree, fill=(0, 0, 0)):
"""
like PIL, rotate by degree, not radians
"""
H, W = img.shape[0], img.shape[1]
center = W / 2, H / 2
M = cv2.getRotationMatrix2D(center, degree, 1)
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
return out
def solarize_func(img, thresh=128):
"""
same output as PIL.ImageOps.posterize
"""
table = np.array([el if el < thresh else 255 - el for el in range(256)])
table = table.clip(0, 255).astype(np.uint8)
out = table[img]
return out
def color_func(img, factor):
"""
same output as PIL.ImageEnhance.Color
"""
## implementation according to PIL definition, quite slow
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
# out = blend(degenerate, img, factor)
# M = (
# np.eye(3) * factor
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
# )[np.newaxis, np.newaxis, :]
M = np.float32(
[[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
) * factor + np.float32([[0.114], [0.587], [0.299]])
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
return out
def contrast_func(img, factor):
"""
same output as PIL.ImageEnhance.Contrast
"""
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
table = (
np.array([(el - mean) * factor + mean for el in range(256)])
.clip(0, 255)
.astype(np.uint8)
)
out = table[img]
return out
def brightness_func(img, factor):
"""
same output as PIL.ImageEnhance.Contrast
"""
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def sharpness_func(img, factor):
"""
The differences the this result and PIL are all on the 4 boundaries, the center
areas are same
"""
kernel = np.ones((3, 3), dtype=np.float32)
kernel[1][1] = 5
kernel /= 13
degenerate = cv2.filter2D(img, -1, kernel)
if factor == 0.0:
out = degenerate
elif factor == 1.0:
out = img
else:
out = img.astype(np.float32)
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
out = out.astype(np.uint8)
return out
def shear_x_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, factor, 0], [0, 1, 0]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
).astype(np.uint8)
return out
def translate_x_func(img, offset, fill=(0, 0, 0)):
"""
same output as PIL.Image.transform
"""
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, -offset], [0, 1, 0]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
).astype(np.uint8)
return out
def translate_y_func(img, offset, fill=(0, 0, 0)):
"""
same output as PIL.Image.transform
"""
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [0, 1, -offset]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
).astype(np.uint8)
return out
def posterize_func(img, bits):
"""
same output as PIL.ImageOps.posterize
"""
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
return out
def shear_y_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [factor, 1, 0]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
).astype(np.uint8)
return out
def cutout_func(img, pad_size, replace=(0, 0, 0)):
replace = np.array(replace, dtype=np.uint8)
H, W = img.shape[0], img.shape[1]
rh, rw = np.random.random(2)
pad_size = pad_size // 2
ch, cw = int(rh * H), int(rw * W)
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
out = img.copy()
out[x1:x2, y1:y2, :] = replace
return out
### level to args
def enhance_level_to_args(MAX_LEVEL):
def level_to_args(level):
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
return level_to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 0.3
if np.random.random() > 0.5:
level = -level
return (level, replace_value)
return level_to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * float(translate_const)
if np.random.random() > 0.5:
level = -level
return (level, replace_value)
return level_to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = int((level / MAX_LEVEL) * cutout_const)
return (level, replace_value)
return level_to_args
def solarize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 256)
return (level,)
return level_to_args
def none_level_to_args(level):
return ()
def posterize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 4)
return (level,)
return level_to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 30
if np.random.random() < 0.5:
level = -level
return (level, replace_value)
return level_to_args
func_dict = {
"Identity": identity_func,
"AutoContrast": autocontrast_func,
"Equalize": equalize_func,
"Rotate": rotate_func,
"Solarize": solarize_func,
"Color": color_func,
"Contrast": contrast_func,
"Brightness": brightness_func,
"Sharpness": sharpness_func,
"ShearX": shear_x_func,
"TranslateX": translate_x_func,
"TranslateY": translate_y_func,
"Posterize": posterize_func,
"ShearY": shear_y_func,
}
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
"Identity": none_level_to_args,
"AutoContrast": none_level_to_args,
"Equalize": none_level_to_args,
"Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
"Solarize": solarize_level_to_args(MAX_LEVEL),
"Color": enhance_level_to_args(MAX_LEVEL),
"Contrast": enhance_level_to_args(MAX_LEVEL),
"Brightness": enhance_level_to_args(MAX_LEVEL),
"Sharpness": enhance_level_to_args(MAX_LEVEL),
"ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
"TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
"TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
"Posterize": posterize_level_to_args(MAX_LEVEL),
"ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
}
class RandomAugment(object):
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
self.N = N
self.M = M
self.isPIL = isPIL
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N)
return [(op, 0.5, self.M) for op in sampled_ops]
def __call__(self, img):
if self.isPIL:
img = np.array(img)
ops = self.get_random_ops()
for name, prob, level in ops:
if np.random.random() > prob:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
return img
class VideoRandomAugment(object):
def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
self.N = N
self.M = M
self.p = p
self.tensor_in_tensor_out = tensor_in_tensor_out
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N, replace=False)
return [(op, self.M) for op in sampled_ops]
def __call__(self, frames):
assert (
frames.shape[-1] == 3
), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
if self.tensor_in_tensor_out:
frames = frames.numpy().astype(np.uint8)
num_frames = frames.shape[0]
ops = num_frames * [self.get_random_ops()]
apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
frames = torch.stack(
list(map(self._aug, frames, ops, apply_or_not)), dim=0
).float()
return frames
def _aug(self, img, ops, apply_or_not):
for i, (name, level) in enumerate(ops):
if not apply_or_not[i]:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
return torch.from_numpy(img)
# if __name__ == "__main__":
# a = RandomAugment()
# img = np.random.randn(32, 32, 3)
# a(img)
class BlipImageTrainProcessor(BlipImageBaseProcessor):
def __init__(
self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
):
super().__init__(mean=mean, std=std)
self.transform = transforms.Compose(
[
transforms.RandomResizedCrop(
image_size,
scale=(min_scale, max_scale),
interpolation=InterpolationMode.BICUBIC,
),
# transforms.RandomHorizontalFlip(),
RandomAugment(
2,
5,
isPIL=True,
augs=[
"Identity",
# "AutoContrast",
"Brightness",
"Sharpness",
"Equalize",
# "ShearX",
# "ShearY",
# "TranslateX",
# "TranslateY",
# "Rotate",
],
),
transforms.ToTensor(),
self.normalize,
]
)
def __call__(self, item):
return self.transform(item)
class BlipImageEvalProcessor(BlipImageBaseProcessor):
def __init__(self, image_size=384, mean=None, std=None):
super().__init__(mean=mean, std=std)
self.transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size), interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
self.normalize,
]
)
def __call__(self, item):
return self.transform(item)
# if __name__ == "__main__":
# a = BlipImageTrainProcessor(image_size=1024)
# # img = np.random.randn(1024, 1024, 3)
# # x = torch.zeros(1024, 1024, 3)
# x = Image.open("/data/codes/vary-main/log/serve_images/2023-05-23/a2a783d89ede819cdeae943a2199ad3d.jpg").convert("RGB")
# print(x.size)
# y = a(x)
# print(y.size())
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
# Implements image augmentation
import albumentations as alb
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def alb_wrapper(transform):
def f(im):
return transform(image=np.asarray(im))["image"]
return f
class Erosion(alb.ImageOnlyTransform):
"""
Apply erosion operation to an image.
Erosion is a morphological operation that shrinks the white regions in a binary image.
Args:
scale (int or tuple/list of int): The scale or range for the size of the erosion kernel.
If an integer is provided, a square kernel of that size will be used.
If a tuple or list is provided, it should contain two integers representing the minimum
and maximum sizes for the erosion kernel.
always_apply (bool, optional): Whether to always apply this transformation. Default is False.
p (float, optional): The probability of applying this transformation. Default is 0.5.
Returns:
numpy.ndarray: The transformed image.
"""
def __init__(self, scale, always_apply=False, p=0.5):
super().__init__(always_apply=always_apply, p=p)
if type(scale) is tuple or type(scale) is list:
assert len(scale) == 2
self.scale = scale
else:
self.scale = (scale, scale)
def apply(self, img, **params):
kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
)
img = cv2.erode(img, kernel, iterations=1)
return img
class Dilation(alb.ImageOnlyTransform):
"""
Apply dilation operation to an image.
Dilation is a morphological operation that expands the white regions in a binary image.
Args:
scale (int or tuple/list of int): The scale or range for the size of the dilation kernel.
If an integer is provided, a square kernel of that size will be used.
If a tuple or list is provided, it should contain two integers representing the minimum
and maximum sizes for the dilation kernel.
always_apply (bool, optional): Whether to always apply this transformation. Default is False.
p (float, optional): The probability of applying this transformation. Default is 0.5.
Returns:
numpy.ndarray: The transformed image.
"""
def __init__(self, scale, always_apply=False, p=0.5):
super().__init__(always_apply=always_apply, p=p)
if type(scale) is tuple or type(scale) is list:
assert len(scale) == 2
self.scale = scale
else:
self.scale = (scale, scale)
def apply(self, img, **params):
kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
)
img = cv2.dilate(img, kernel, iterations=1)
return img
class Bitmap(alb.ImageOnlyTransform):
"""
Apply a bitmap-style transformation to an image.
This transformation replaces all pixel values below a certain threshold with a specified value.
Args:
value (int, optional): The value to replace pixels below the threshold with. Default is 0.
lower (int, optional): The threshold value below which pixels will be replaced. Default is 200.
always_apply (bool, optional): Whether to always apply this transformation. Default is False.
p (float, optional): The probability of applying this transformation. Default is 0.5.
Returns:
numpy.ndarray: The transformed image.
"""
def __init__(self, value=0, lower=200, always_apply=False, p=0.5):
super().__init__(always_apply=always_apply, p=p)
self.lower = lower
self.value = value
def apply(self, img, **params):
img = img.copy()
img[img < self.lower] = self.value
return img
train_transform = alb_wrapper(
alb.Compose(
[
Bitmap(p=0),
alb.OneOf([Erosion((2, 3)), Dilation((2, 3))], p=0.02),
alb.Affine(shear={"x": (0, 3), "y": (-3, 0)}, cval=(255, 255, 255), p=0.03),
alb.ShiftScaleRotate(
shift_limit_x=(0, 0.04),
shift_limit_y=(0, 0.03),
scale_limit=(-0.15, 0.03),
rotate_limit=2,
border_mode=0,
interpolation=2,
value=(255, 255, 255),
p=0.03,
),
alb.GridDistortion(
distort_limit=0.05,
border_mode=0,
interpolation=2,
value=(255, 255, 255),
p=0.04,
),
alb.Compose(
[
alb.Affine(
translate_px=(0, 5), always_apply=True, cval=(255, 255, 255)
),
alb.ElasticTransform(
p=1,
alpha=50,
sigma=120 * 0.1,
alpha_affine=120 * 0.01,
border_mode=0,
value=(255, 255, 255),
),
],
p=0.04,
),
alb.RandomBrightnessContrast(0.1, 0.1, True, p=0.03),
alb.ImageCompression(95, p=0.07),
alb.GaussNoise(20, p=0.08),
alb.GaussianBlur((3, 3), p=0.03),
alb.Resize(1024, 1024),
alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
ToTensorV2(),
]
)
)
test_transform = alb_wrapper(
alb.Compose(
[
alb.Resize(1024, 1024),
alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
ToTensorV2(),
]
)
)
# if __name__ == '__main__':
# from PIL import Image
# image = Image.open('/data/hypertext/ucaswei/codes_new/show/49.jpg').convert('RGB')
# # image = np.array(image)
# for i in range(100):
# image1 = train_transform(image)
# image1 = Image.fromarray(np.uint8(image1))
# image1.save('/data/hypertext/ucaswei/codes_new/aug/' + str(i) + '.jpg')
# mm_projector_1 = nn.Linear(1024, 256)
# x = torch.zeros(2, 3, 1024, 1024)
# with torch.no_grad():
# y = model(x)
# print(y.shape)
# y = mm_projector_1(y.permute(0,2,1))
# print(y.shape)
# print(y.permute(0,2,1).shape)
\ No newline at end of file
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, \
LlamaConfig, LlamaModel, LlamaForCausalLM, \
CLIPVisionModel, CLIPImageProcessor
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from vary.utils.constants import *
from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.vision_encoder.sam import build_sam_vit_b
from transformers import OPTConfig, OPTModel, OPTForCausalLM
from vary.model.plug.transforms import train_transform, test_transform
class varyConfig(OPTConfig):
model_type = "vary"
class varyOPTModel(OPTModel):
config_class = varyConfig
def __init__(self, config: OPTConfig):
super(varyOPTModel, self).__init__(config)
self.vision_tower_2 = build_sam_vit_b()
self.mm_projector = nn.Linear(1024, 768)
def initialize_vision_modules(
self,
vision_tower,
pretrained_stage1_model=None,
freeze_vision_tower=False,
use_im_start_end=False,
vision_select_layer=-1,
dtype=torch.float16,
device="cuda"
):
# 224*224
# image_processor do not used in opt
image_processor = CLIPImageProcessor.from_pretrained('/cache/vit-large-patch14')
# 1024*1024
image_processor_high = train_transform
image_token_len = 256
self.config.vision_tower = vision_tower
self.config.image_token_len = image_token_len
self.config.use_im_start_end = True
self.config.vision_select_layer = vision_select_layer
self.config.freeze_vision_tower = freeze_vision_tower
return dict(
image_processor=image_processor,
image_processor_high=image_processor_high,
image_token_len=image_token_len,
)
def embed_tokens(self, x):
return self.get_input_embeddings()(x)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# HACK: replace back original embeddings for LLaVA pretraining
# orig_embeds_params = getattr(self, 'orig_embeds_params', None)
# if orig_embeds_params is not None:
# with torch.no_grad():
# self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# inputs_embeds = self.wte(input_ids)
vision_tower = getattr(self, 'vision_tower_2', None)
if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
use_im_start_end = getattr(self.config, "use_im_start_end", -1)
vision_select_layer = getattr(self.config, "vision_select_layer", -1)
im_patch_token = getattr(self.config, "im_patch_token", -1)
im_start_token = getattr(self.config, "im_start_token", -1)
im_end_token = getattr(self.config, "im_end_token", -1)
freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
image_features = []
for image in images:
with torch.set_grad_enabled(True):
cnn_feature = vision_tower(image[1])
cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1)
image_feature_final = cnn_feature
image_features.append(image_feature_final)
if type(images) is list:
image_features = [self.mm_projector(image_feature) for image_feature in image_features]
else:
# image_features = self.mm_projector(image_features)
raise NotImplementedError
# dummy_image_features = torch.zeros(1024, 1664, device=inputs_embeds.device, dtype=inputs_embeds.dtype).permute(0, 2, 1).reshape(dummy_image_features.shape[0], -1, 32, 32)
# VIT 1024; CNN:1024
dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features = self.mm_projector(dummy_image_features)
use_im_start_end = True
new_input_embeds = []
for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
if (cur_input_ids == im_patch_token).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
continue
if use_im_start_end:
if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
raise ValueError("The number of image start tokens and image end tokens should be the same.")
image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
num_patches = per_cur_image_features.shape[0]
# print(cur_input_ids)
if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
raise ValueError("The image end token should follow the image start token.")
cur_input_embeds = torch.cat(
(
cur_input_embeds[:image_start_token_pos+1],
per_cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches + 1:]
),
dim=0
)
new_input_embeds.append(cur_input_embeds)
else:
raise NotImplementedError
inputs_embeds = torch.stack(new_input_embeds, dim=0)
return super(varyOPTModel, self).forward(
input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, use_cache=use_cache,
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
return_dict=return_dict
)
class varyOPTForCausalLM(OPTForCausalLM):
config_class = varyConfig
# supports_gradient_checkpointing = True
def __init__(self, config):
super(OPTForCausalLM, self).__init__(config)
self.model = varyOPTModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
# def _set_gradient_checkpointing(self, module, value=False):
# if isinstance(module, varyQwenModel):
# module.gradient_checkpointing = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
# print(input_ids)
# print(len(images))
outputs = self.model(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
images=images,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states).contiguous()
# logits
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None)
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"images": kwargs.get("images", None),
}
)
return model_inputs
def initialize_vision_tokenizer(
self,
tokenizer,
freeze_lm_model=False,
pretrained_stage1_model=None,
device="cuda"
):
config = self.get_model().config
# add image patch token <image>
tokenizer.add_tokens("</s>", special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
tokenizer.add_tokens(DEFAULT_IMAGE_PATCH_TOKEN, special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
config.im_patch_token = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_PATCH_TOKEN)
config.use_im_start_end = True
# add image start token <im_start> and end token <im_end>
if config.use_im_start_end:
num_new_tokens = 2
tokenizer.add_tokens(DEFAULT_IM_START_TOKEN , special_tokens=True)
tokenizer.add_tokens(DEFAULT_IM_END_TOKEN , special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
config.im_start_token = tokenizer.convert_tokens_to_ids(DEFAULT_IM_START_TOKEN)
config.im_end_token = tokenizer.convert_tokens_to_ids(DEFAULT_IM_END_TOKEN)
# config.im_start_token, config.im_end_token = 151857, 151858
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
AutoConfig.register("vary", varyConfig)
AutoModelForCausalLM.register(varyConfig, varyOPTForCausalLM)
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, \
CLIPVisionModel, CLIPImageProcessor
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from vary.utils.constants import *
from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.llm.qwen.modeling_qwen import QWenLMHeadModel, QWenModel
from vary.model.llm.qwen.configuration_qwen import QWenConfig
from vary.model.vision_encoder.sam import build_sam_vit_b
from vary.model.plug.transforms import train_transform, test_transform
class varyConfig(QWenConfig):
model_type = "vary"
class varyQwenModel(QWenModel):
config_class = varyConfig
def __init__(self, config: QWenConfig):
super(varyQwenModel, self).__init__(config)
# TODO download the clip-vit in huggingface
self.vision_tower = CLIPVisionModel.from_pretrained('/home/wanglch/projects/Vary/cache/vit-large-patch14')
self.vision_tower_high = build_sam_vit_b() # build_sam_vit_b(checkpoint = 'xxxx') for train
self.mm_projector = nn.Linear(1024, 2048)
self.mm_projector_vary = nn.Linear(1024, 2048)
def initialize_vision_modules(
self,
vision_tower,
pretrained_stage1_model=None,
freeze_vision_tower=False,
use_im_start_end=False,
vision_select_layer=-1,
dtype=torch.float16,
device="cuda"
):
# 224*224
# TODO download the clip-vit in huggingface
image_processor = CLIPImageProcessor.from_pretrained('/home/wanglch/projects/Vary/cache/vit-large-patch14')
# 1024*1024
image_processor_high = train_transform
self.vision_tower = self.vision_tower.to(dtype=dtype, device=device)
self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
self.mm_projector = self.mm_projector.to(dtype=dtype, device=device)
self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
image_token_len = 256
self.config.vision_tower = vision_tower
self.config.image_token_len = image_token_len
self.config.use_im_start_end = True
self.config.vision_select_layer = vision_select_layer
self.config.freeze_vision_tower = freeze_vision_tower
return dict(
image_processor=image_processor,
image_processor_high=image_processor_high,
image_token_len=image_token_len,
)
def embed_tokens(self, x):
return self.wte(x)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# HACK: replace back original embeddings for LLaVA pretraining
# orig_embeds_params = getattr(self, 'orig_embeds_params', None)
# if orig_embeds_params is not None:
# with torch.no_grad():
# self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# inputs_embeds = self.wte(input_ids)
vision_tower = getattr(self, 'vision_tower', None)
vision_tower_high = getattr(self, 'vision_tower_high', None)
if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
use_im_start_end = getattr(self.config, "use_im_start_end", -1)
vision_select_layer = getattr(self.config, "vision_select_layer", -1)
# im_patch_token = getattr(self.config, "im_patch_token", -1)
# im_start_token = getattr(self.config, "im_start_token", -1)
# im_end_token = getattr(self.config, "im_end_token", -1)
# freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
im_patch_token = 151859
im_start_token = 151857
im_end_token = 151858
image_features_1 = []
image_features_2 = []
for image in images:
with torch.set_grad_enabled(False):
image_forward_out = vision_tower(image[0], output_hidden_states=True)
select_hidden_state = image_forward_out.hidden_states[vision_select_layer]
image_feature = select_hidden_state[:, 1:] # 256*1024
with torch.set_grad_enabled(False):
cnn_feature = vision_tower_high(image[1])
cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
image_features_1.append(image_feature)
image_features_2.append(cnn_feature)
if type(images) is list:
image_features_1 = [self.mm_projector(image_feature) for image_feature in image_features_1]
image_features_2 = [self.mm_projector_vary(image_feature) for image_feature in image_features_2]
image_features = [torch.cat((image_feature[0], image_feature[1]), dim=-1) for image_feature in zip(image_features_1, image_features_2)]
else:
raise NotImplementedError
# dummy_image_features = torch.zeros(256, 4096, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features_1 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features_1 = self.mm_projector(dummy_image_features_1)
dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2)
dummy_image_features = torch.cat((dummy_image_features_1, dummy_image_features_2), dim=-1)
use_im_start_end = True
new_input_embeds = []
for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
if (cur_input_ids == im_patch_token).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
continue
if use_im_start_end:
if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
raise ValueError("The number of image start tokens and image end tokens should be the same.")
image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
num_patches = per_cur_image_features.shape[0]
if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
raise ValueError("The image end token should follow the image start token.")
# if orig_embeds_params is not None:
# cur_new_input_embeds = torch.cat(
# (
# cur_input_embeds[:image_start_token_pos].detach(),
# cur_input_embeds[image_start_token_pos:image_start_token_pos+1],
# per_cur_image_features,
# cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2],
# cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()
# ),
# dim=0
# )
# else:
cur_input_embeds = torch.cat(
(
cur_input_embeds[:image_start_token_pos+1],
per_cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches + 1:]
),
dim=0
)
new_input_embeds.append(cur_input_embeds)
else:
raise NotImplementedError
inputs_embeds = torch.stack(new_input_embeds, dim=0)
return super(varyQwenModel, self).forward(
input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, use_cache=use_cache,
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
return_dict=return_dict
)
class varyQwenForCausalLM(QWenLMHeadModel):
config_class = varyConfig
# supports_gradient_checkpointing = True
def __init__(self, config):
super(QWenLMHeadModel, self).__init__(config)
self.transformer = varyQwenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.transformer
# def _set_gradient_checkpointing(self, module, value=False):
# if isinstance(module, varyQwenModel):
# module.gradient_checkpointing = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
transformer_outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
images=images,
return_dict=return_dict
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
# logits
loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
# print(loss)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None)
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"images": kwargs.get("images", None),
}
)
return model_inputs
def initialize_vision_tokenizer(
self,
tokenizer,
freeze_lm_model=False,
pretrained_stage1_model=None,
device="cuda"
):
config = self.get_model().config
self.resize_token_embeddings(len(tokenizer))
config.im_patch_token = 151859
config.use_im_start_end = True
# add image start token <im_start> and end token <im_end>
if config.use_im_start_end:
# num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
# config.im_start_token, config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
config.im_start_token, config.im_end_token = 151857, 151858
AutoConfig.register("vary", varyConfig)
AutoModelForCausalLM.register(varyConfig, varyQwenForCausalLM)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from functools import partial
import torch
import torch.nn as nn
from typing import Type
import math
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2))
x = self.net_2(x)
x = self.net_3(x)
return x
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
def build_sam_vit_b(checkpoint=None):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)
if checkpoint is not None:
# with open(checkpoint, "rb") as f:
state_dict = torch.load(checkpoint)
image_encoder.load_state_dict(state_dict, strict=True)
# image_encoder.load_state_dict({k[19:]: v for k, v in state_dict.items() if 'vision_tower' in k}, strict=True)
print(checkpoint)
return image_encoder
if __name__ == '__main__':
x = torch.zeros(2, 3, 1024, 1024)
# x.permute(0, 3, 1, 2)
net = build_sam_vit_b(checkpoint ='')
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
from vary.train.train import train
if __name__ == "__main__":
train()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment