Commit c873301f authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
import io
import os
import copy
import json
import logging
import torch
import random
from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from vary.data.base_dataset import BaseDataset
from vary.utils.constants import *
from vary.utils import conversation as conversation_lib
class ConversationDataset(BaseDataset):
"""Conversation format dataset stage2 fine-tuning."""
def __init__(self, datasets, tokenizer, multimodal_cfg):
super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
# v0 version format conversation
conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
logging.warning("Formatting inputs into conversation type: mpt-fixed")
logging.warning("Loading data...")
list_data_dict = []
list_image_path = []
for name in datasets.split("+"):
# for name in vary_data_dict[name_all]:
dataset = CONVERSATION_DATA[name]
data_path = dataset['annotations']
data = json.load(open(data_path, "r"))
list_data_dict.extend(data)
image_path = dataset['images']
list_image_path.extend([image_path] * len(data))
logging.warning(f"Data from {data_path} provide {len(data)} conversations.")
assert len(list_data_dict) == len(list_image_path)
logging.warning(f"{len(list_data_dict)} conversations in total.")
a_new_list = list(zip(list_data_dict, list_image_path))
random.shuffle(a_new_list)
list_data_dict_new, list_image_path_new = zip(*a_new_list)
self.list_data_dict = list_data_dict_new
self.list_image_path = list_image_path_new
self.im_patch_token = 151859
self.im_start_token = 151857
self.im_end_token = 151858
def multimodal_processor(self, sources):
for source in sources:
if self.multimodal_cfg['sep_image_conv_front']:
assert DEFAULT_IMAGE_TOKEN in source[0]['value']
source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
for sentence in source:
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
# if self.multimodal_cfg['use_im_start_end']:
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token)
return sources
def _tokenize_fn(self, strings):
"""Tokenize a list of strings."""
tokenized_list = [
self.tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=self.tokenizer.model_max_length,
truncation=True,
) for text in strings
]
input_ids = labels = [
tokenized.input_ids[0] for tokenized in tokenized_list
]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(self, target, tokenized_lens, speakers):
# cur_idx = 0
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:]
target[:cur_idx] = IGNORE_INDEX
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker.lower() == "human":
target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
cur_idx += tokenized_len
def token_processor(self, sources):
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
input_ids = self.tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=self.tokenizer.model_max_length,
truncation=True,
).input_ids
# input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
# Mask targets
sep = conv.sep + conv.roles[1]
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(self.tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep)
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
for conv_idx in range(3, len(rounds), 2):
re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
cur_len = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids)
instruction_len = len(self.tokenizer(parts[0]).input_ids)
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < self.tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# data = self.list_data_dict[i]
data = copy.deepcopy(self.list_data_dict[i])
if isinstance(data, dict):
if 'image' in data:
image_path = self.list_image_path[i]
image_file = data['image']
try:
image = Image.open(image_path + image_file).convert('RGB')
except:
print(f'cannot identify image file {image_path + image_file}.')
return self.__getitem__(0)
try:
image, image_1 = self.image_processor(image)
except:
print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!')
return self.__getitem__(0)
conversations = self.multimodal_processor([data["conversations"]])
else:
conversations = [data]
# align with fastchat & llava here, put the conversation into a list for tokenization
data_dict = self.token_processor(conversations)
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
if isinstance(data, dict) and 'image' in data:
data_dict['image'] = [image]
data_dict['image_high'] = [image_1]
else:
crop_size = self.multimodal_cfg['image_processor'].crop_size
data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
return data_dict
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.vision_encoder.sam import build_sam_vit_b
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def eval_model(args):
# Model
disable_torch_init()
model_name = os.path.expanduser(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = varyOPTForCausalLM.from_pretrained(model_name)
model.to(device='cuda', dtype=torch.bfloat16)
image_processor_high = test_transform
image_token_len = 256
prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
inputs = tokenizer([prompt])
image = load_image(args.image_file)
image_1 = image.copy()
image_tensor_1 = image_processor_high(image_1).to(torch.bfloat16)
input_ids = torch.as_tensor(inputs.input_ids).cuda()
stop_str = '</s>'
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = model.generate(
input_ids,
images=[(image_tensor_1.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).cuda())],
do_sample=True,
num_beams = 1,
streamer=streamer,
max_new_tokens=4096,
stopping_criteria=[stopping_criteria]
)
# input_token_len = input_ids.shape[1]
# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
# if outputs.endswith(stop_str):
# outputs = outputs[:-len(stop_str)]
# outputs = outputs.strip()
# print(outputs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--image-file", type=str, required=True)
# parser.add_argument("--query", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
args = parser.parse_args()
eval_model(args)
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from vary.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def eval_model(args):
# Model
disable_torch_init()
model_name = os.path.expanduser(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = varyQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', trust_remote_code=True)
model.to(device='cuda', dtype=torch.bfloat16)
image_processor = CLIPImageProcessor.from_pretrained("/data/hypertext/ucaswei/cache/vit-large-patch14/vit-large-patch14/", torch_dtype=torch.float16)
image_processor_high = BlipImageEvalProcessor(image_size=1024)
use_im_start_end = True
image_token_len = 256
qs = 'Provide the ocr results of this image.'
if use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
conv_mode = "mpt"
args.conv_mode = conv_mode
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
image = load_image(args.image_file)
image_1 = image.copy()
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
image_tensor_1 = image_processor_high(image_1)
input_ids = torch.as_tensor(inputs.input_ids).cuda()
# stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = model.generate(
input_ids,
images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
do_sample=True,
num_beams = 1,
# temperature=0.2,
streamer=streamer,
max_new_tokens=2048,
stopping_criteria=[stopping_criteria]
)
# print(output_ids)
# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
# # conv.messages[-1][-1] = outputs
# if outputs.endswith(stop_str):
# outputs = outputs[:-len(stop_str)]
# outputs = outputs.strip()
# print(outputs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--image-file", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
args = parser.parse_args()
eval_model(args)
from .vary_opt import varyOPTModel, varyOPTForCausalLM
# from .vary_qwen_vary import varyQwenModel, varyQwenForCausalLM, varyConfig
from .vary_toy_qwen1_8 import varyQwenModel, varyQwenForCausalLM, varyConfig
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OPT model."""
import random
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from transformers.models.opt.configuration_opt import OPTConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
_CONFIG_FOR_DOC = "OPTConfig"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
# Base model docstring
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
# SequenceClassification docstring
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
_SEQ_CLASS_EXPECTED_LOSS = 1.71
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
# QuestionAnswering docstring
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
_QA_EXPECTED_LOSS = 7.41
_QA_TARGET_START_INDEX = 14
_QA_TARGET_END_INDEX = 15
OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-2.7b",
"facebook/opt-6.7b",
"facebook/opt-13b",
"facebook/opt-30b",
# See all OPT models at https://huggingface.co/models?filter=opt
]
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
class OPTLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(
self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
# create positions depending on attention_mask
positions = (
torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
).long() - 1
# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
return super().forward(positions + self.offset)
class OPTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attention_mask
)
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
if attn_weights.dtype == torch.float16:
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(torch.float16)
else:
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights_reshaped.view(
bsz * self.num_heads, tgt_len, src_len
)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = OPTAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = (residual + hidden_states).view(hidden_states_shape)
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
OPT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`OPTConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare OPT Model outputting raw hidden-states without any specific head on top.",
OPT_START_DOCSTRING,
)
class OPTPreTrainedModel(PreTrainedModel):
config_class = OPTConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["OPTDecoderLayer"]
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (OPTDecoder)):
module.gradient_checkpointing = value
OPT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class OPTDecoder(OPTPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
Args:
config: OPTConfig
"""
def __init__(self, config: OPTConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.word_embed_proj_dim, self.padding_idx
)
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size
)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = nn.Linear(
config.hidden_size, config.word_embed_proj_dim, bias=False
)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = nn.Linear(
config.word_embed_proj_dim, config.hidden_size, bias=False
)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList(
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
past_key_values_length=past_key_values_length,
).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
query_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if query_embeds is not None:
inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
input_shape = inputs_embeds.size()[:-1]
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device
)
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@add_start_docstrings(
"The bare OPT Model outputting raw hidden-states without any specific head on top.",
OPT_START_DOCSTRING,
)
class OPTModel(OPTPreTrainedModel):
def __init__(self, config: OPTConfig):
super().__init__(config)
self.decoder = OPTDecoder(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.decoder.embed_tokens
def set_input_embeddings(self, value):
self.decoder.embed_tokens = value
def get_decoder(self):
return self.decoder
@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
query_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
query_embeds=query_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs
return BaseModelOutputWithPast(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
)
class OPTForCausalLM(OPTPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = OPTModel(config)
# the lm_head weight is automatically tied to the embed tokens weight
self.lm_head = nn.Linear(
config.word_embed_proj_dim, config.vocab_size, bias=False
)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
query_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
reduction: Optional[str] = "mean",
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Returns:
Example:
```python
>>> from transformers import GPT2Tokenizer, OPTForCausalLM
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
query_embeds=query_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0]).contiguous()
loss = None
if labels is not None:
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction=reduction)
loss = loss_fct(
shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
)
if reduction == "none":
loss = loss.view(shift_logits.size(0), -1).sum(1)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids=None,
query_embeds=None,
past=None,
attention_mask=None,
use_cache=None,
**kwargs,
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
if input_ids is not None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
query_embeds = None
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids,
"query_embeds": query_embeds,
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx) for past_state in layer_past
),
)
return reordered_past
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from transformers import PretrainedConfig
class QWenConfig(PretrainedConfig):
model_type = "qwen"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
emb_dropout_prob=0.0,
attn_dropout_prob=0.0,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
max_position_embeddings=8192,
scale_attn_weights=True,
use_cache=True,
bf16=False,
fp16=False,
fp32=False,
kv_channels=128,
rotary_pct=1.0,
rotary_emb_base=10000,
use_dynamic_ntk=True,
use_logn_attn=True,
use_flash_attn="auto",
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.emb_dropout_prob = emb_dropout_prob
self.attn_dropout_prob = attn_dropout_prob
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.bf16 = bf16
self.fp16 = fp16
self.fp32 = fp32
self.kv_channels = kv_channels
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.no_bias = no_bias
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs
)
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
try:
from einops import rearrange
except ImportError:
rearrange = None
from torch import nn
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
from .configuration_qwen import QWenConfig
from .qwen_generation_utils import (
HistoryType,
make_context,
decode_tokens,
get_stop_words_ids,
StopWordsLogitsProcessor,
)
# from .visual import VisionTransformer
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "qwen"
_CONFIG_FOR_DOC = "QWenConfig"
QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
_ERROR_BAD_CHAT_FORMAT = """\
We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
"""
_SENTINEL = object()
_ERROR_STREAM_IN_CHAT = """\
Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
"""
apply_rotary_emb_func = None
rms_norm = None
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class QWenAttention(nn.Module):
def __init__(self, config):
super().__init__()
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(
torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions),
persistent=False,
)
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
self.seq_length = config.seq_length
self.hidden_size = config.hidden_size
self.split_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.scale_attn_weights = True
self.projection_size = config.kv_channels * config.num_attention_heads
assert self.projection_size % config.num_attention_heads == 0
self.hidden_size_per_attention_head = (
self.projection_size // config.num_attention_heads
)
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
self.c_proj = nn.Linear(
config.hidden_size, self.projection_size, bias=not config.no_bias
)
self.is_fp32 = not (config.bf16 or config.fp16)
self.bf16 = config.bf16
if config.rotary_pct == 1.0:
self.rotary_ndims = None
else:
assert config.rotary_pct < 1
self.rotary_ndims = int(
self.hidden_size_per_attention_head * config.rotary_pct
)
dim = (
self.rotary_ndims
if self.rotary_ndims is not None
else self.hidden_size_per_attention_head
)
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
self.use_dynamic_ntk = config.use_dynamic_ntk
self.use_logn_attn = config.use_logn_attn
logn_list = [
math.log(i, self.seq_length) if i > self.seq_length else 1
for i in range(1, 32768)
]
self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]
self._ntk_cached = 1.0
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
attn_weights = attn_weights / torch.full(
[],
value.size(-1) ** 0.5,
dtype=attn_weights.dtype,
device=attn_weights.device,
)
query_length, key_length = query.size(-2), key.size(-2)
# causal_mask = self.bias[
# :, :, key_length - query_length : key_length, :key_length
# ]
# mask_value = torch.finfo(attn_weights.dtype).min
# mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
# attn_weights.device
# )
# attn_weights = torch.where(
# causal_mask, attn_weights.to(attn_weights.dtype), mask_value
# )
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2)
return attn_output, attn_weights
def _upcast_and_reordered_attn(
self, query, key, value, attention_mask=None, head_mask=None
):
bsz, num_heads, q_seq_len, dk = query.size()
_, _, k_seq_len, _ = key.size()
attn_weights = torch.empty(
bsz * num_heads,
q_seq_len,
k_seq_len,
dtype=torch.float32,
device=query.device,
)
scale_factor = 1.0
if self.scale_attn_weights:
scale_factor /= float(value.size(-1)) ** 0.5
with autocast(enabled=False):
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
-1, dk, k_seq_len
)
attn_weights = torch.baddbmm(
attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[
:, :, key_length - query_length : key_length, :key_length
]
mask_value = torch.finfo(attn_weights.dtype).min
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
attn_weights.device
)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if attn_weights.dtype != torch.float32:
raise RuntimeError(
"Error with upcasting, attn_weights does not have dtype torch.float32"
)
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor
def _merge_heads(self, tensor, num_heads, attn_head_size):
tensor = tensor.contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
):
mixed_x_layer = self.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
kv_seq_len = hidden_states.size()[1]
if layer_past:
# layer past[0] shape: bs * seq_len * head_num * dim
kv_seq_len += layer_past[0].shape[1]
if (
self.use_dynamic_ntk
and kv_seq_len == hidden_states.size()[1]
and not self.training
):
context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
ntk_alpha = 2 ** math.ceil(context_value) - 1
ntk_alpha = max(ntk_alpha, 1)
self._ntk_cached = ntk_alpha
else:
ntk_alpha = self._ntk_cached
rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
hidden_states.device
)
if rotary_pos_emb is not None:
if isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = rotary_pos_emb
else:
rotary_pos_emb = (rotary_pos_emb,) * 2
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
cur_len = query.shape[1]
q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb)
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
if use_cache:
present = (key, value)
else:
present = None
if self.use_logn_attn and not self.training:
if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
seq_start = key.size(1) - query.size(1)
seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
query = query * logn_tensor.expand_as(query)
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
attn_output, attn_weight = self._attn(
query, key, value, attention_mask, head_mask
)
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
)
attn_output = self.c_proj(context_layer)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weight,)
return outputs
class QWenMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.w1 = nn.Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
self.w2 = nn.Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
ff_dim_in = config.intermediate_size // 2
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
def forward(self, hidden_states):
a1 = self.w1(hidden_states)
a2 = self.w2(hidden_states)
intermediate_parallel = a1 * F.silu(a2)
output = self.c_proj(intermediate_parallel)
return output
class QWenBlock(nn.Module):
def __init__(self, config):
super().__init__()
hidden_size = config.hidden_size
self.bf16 = config.bf16
self.ln_1 = RMSNorm(
hidden_size,
eps=config.layer_norm_epsilon,
)
self.attn = QWenAttention(config)
self.ln_2 = RMSNorm(
hidden_size,
eps=config.layer_norm_epsilon,
)
self.mlp = QWenMLP(config)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
):
layernorm_output = self.ln_1(hidden_states)
attn_outputs = self.attn(
layernorm_output,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0]
outputs = attn_outputs[1:]
residual = hidden_states
layernorm_input = attn_output + residual
layernorm_output = self.ln_2(layernorm_input)
residual = layernorm_input
mlp_output = self.mlp(layernorm_output)
hidden_states = residual + mlp_output
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
return outputs
class QWenPreTrainedModel(PreTrainedModel):
config_class = QWenConfig
base_model_prefix = "transformer"
is_parallelizable = False
supports_gradient_checkpointing = True
_no_split_modules = ["QWenBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, RMSNorm):
module.weight.data.fill_(1.0)
for name, p in module.named_parameters():
if name == "c_proj.weight":
p.data.normal_(
mean=0.0,
std=(
self.config.initializer_range
/ math.sqrt(2 * self.config.num_hidden_layers)
),
)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, QWenModel):
module.gradient_checkpointing = value
class QWenModel(QWenPreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config):
super().__init__(config)
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
self.embed_dim = config.hidden_size
self.gradient_checkpointing = False
self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
self.drop = nn.Dropout(config.emb_dropout_prob)
self.h = nn.ModuleList(
[
QWenBlock(
config,
)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = RMSNorm(
self.embed_dim,
eps=config.layer_norm_epsilon,
)
# self.visual = VisionTransformer(**config.visual)
self.post_init()
def get_input_embeddings(self):
return self.wte
def set_input_embeddings(self, new_embeddings):
self.wte = new_embeddings
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
# bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
# eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
# assert (bos_pos[0] == eos_pos[0]).all()
# img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
# images = []
# for i, a, b in img_pos:
# image = input_ids[i][a + 1 : b - 1].tolist()
# image = image[ : image.index(self.config.visual['image_start_id'] + 2)]
# images.append(bytes(image).decode('utf-8'))
# images = self.visual.encode(images)
# assert images.shape[0] == len(images)
# else:
# images = None
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
encoder_attention_mask = None
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_length
)
hidden_states = inputs_embeds
hidden_states = self.drop(hidden_states)
# if images is not None:
# for idx, (i, a, b) in enumerate(img_pos):
# hidden_states[i][a + 1 : b] = images[idx]
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[1],)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class QWenLMHeadModel(QWenPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
def __init__(self, config):
super().__init__(config)
assert (
config.bf16 + config.fp16 + config.fp32 <= 1
), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
if autoset_precision:
if SUPPORT_BF16:
logger.warn(
"The model is automatically converting to bf16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.bf16 = True
elif SUPPORT_FP16:
logger.warn(
"The model is automatically converting to fp16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.fp16 = True
else:
config.fp32 = True
if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
if config.fp32:
if SUPPORT_BF16:
logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
elif SUPPORT_FP16:
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
self.transformer = QWenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.bf16:
self.transformer.bfloat16()
self.lm_head.bfloat16()
if config.fp16:
self.transformer.half()
self.lm_head.half()
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None)
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
)
return model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
return tuple(
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
)
for layer_past in past_key_values
)
def chat(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
append_history: bool = True,
stream: Optional[bool] = _SENTINEL,
stop_words_ids: Optional[List[List[int]]] = None,
**kwargs,
) -> Tuple[str, HistoryType]:
assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
if max_window_size is None:
max_window_size = self.generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
history=history,
system=system,
max_window_size=max_window_size,
chat_format=self.generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
self.generation_config.chat_format, tokenizer
))
input_ids = torch.tensor([context_tokens]).to(self.device)
outputs = self.generate(
input_ids,
stop_words_ids = stop_words_ids,
return_dict_in_generate = False,
**kwargs,
)
response = decode_tokens(
outputs[0],
tokenizer,
raw_text_len=len(raw_text),
context_length=len(context_tokens),
chat_format=self.generation_config.chat_format,
verbose=False,
errors='replace'
)
if append_history:
history.append((query, response))
return response, history
def chat_stream(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
stop_words_ids: Optional[List[List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = None,
**kwargs,
) -> Generator[str, Any, None]:
assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
if max_window_size is None:
max_window_size = self.generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
history=history,
system=system,
max_window_size=max_window_size,
chat_format=self.generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
self.generation_config.chat_format, tokenizer
))
if stop_words_ids is not None:
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=self.generation_config.eos_token_id,
)
if logits_processor is None:
logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
logits_processor.append(stop_words_logits_processor)
input_ids = torch.tensor([context_tokens]).to(self.device)
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
self.__class__.generate_stream = NewGenerationMixin.generate
self.__class__.sample_stream = NewGenerationMixin.sample_stream
stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
def stream_generator():
outputs = []
for token in self.generate_stream(
input_ids,
return_dict_in_generate=False,
generation_config=stream_config,
logits_processor=logits_processor,
seed=-1,
**kwargs):
outputs.append(token.item())
yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
return stream_generator()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
# Process stop_words_ids.
stop_words_ids = kwargs.pop("stop_words_ids", None)
if stop_words_ids is None and generation_config is not None:
stop_words_ids = getattr(generation_config, "stop_words_ids", None)
if stop_words_ids is None:
stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
if stop_words_ids is not None:
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=self.generation_config.eos_token_id,
)
if logits_processor is None:
logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
logits_processor.append(stop_words_logits_processor)
return super().generate(
inputs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
**kwargs,
)
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
self.dim = dim
self.base = base
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
if importlib.util.find_spec("einops") is None:
raise RuntimeError("einops is required for Rotary Embedding")
self._rotary_pos_emb_cache = None
self._seq_len_cached = 0
self._ntk_alpha_cached = 1.0
def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
seqlen = max_seq_len + offset
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
self.inv_freq = 1.0 / (
base
** (
torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
/ self.dim
)
)
self._seq_len_cached = max(2 * seqlen, 16)
self._ntk_alpha_cached = ntk_alpha
seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
from einops import rearrange
self._rotary_pos_emb_cache = rearrange(emb, "n d -> 1 n 1 d")
def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]
def _rotate_half(x):
from einops import rearrange
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs):
if apply_rotary_emb_func is not None and t.is_cuda:
t_ = t.float()
freqs = freqs.squeeze(0).squeeze(1)
cos = freqs[:, : freqs.shape[-1] // 2].cos()
sin = freqs[:, : freqs.shape[-1] // 2].sin()
output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
return output
else:
rot_dim = freqs.shape[-1]
t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
t_ = t_.float()
t_pass_ = t_pass_.float()
t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
return torch.cat((t_, t_pass_), dim=-1).type_as(t)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
if rms_norm is not None and x.is_cuda:
return rms_norm(x, self.weight, self.eps)
else:
output = self._norm(x.float()).type_as(x)
return output * self.weight
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Generation support."""
from typing import Tuple, List, Union, Iterable
import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor
logger = logging.get_logger(__name__)
# Types.
HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]
def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
for tokens in batch:
context_length = len(tokens)
if context_length < seq_length:
tokens.extend([pad_id] * (seq_length - context_length))
return batch
def get_ltor_masks_and_position_ids(
data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss,
):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modifed based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indecies:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, loss_mask, position_ids
def get_batch(context_tokens: torch.LongTensor, eod_id: int):
"""Generate batch from context tokens."""
# Move to GPU.
tokens = context_tokens.contiguous().to(context_tokens.device)
# Get the attention mask and postition ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
eod_id,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False,
)
return tokens, attention_mask, position_ids
def get_stop_words_ids(chat_format, tokenizer):
if chat_format == "raw":
stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
elif chat_format == "chatml":
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
return stop_words_ids
def make_context(
tokenizer: PreTrainedTokenizer,
query: str,
history: List[Tuple[str, str]] = None,
system: str = "",
max_window_size: int = 6144,
chat_format: str = "chatml",
):
if history is None:
history = []
if chat_format == "chatml":
im_start, im_end = "<|im_start|>", "<|im_end|>"
im_start_tokens = [tokenizer.im_start_id]
im_end_tokens = [tokenizer.im_end_id]
nl_tokens = tokenizer.encode("\n")
def _tokenize_str(role, content):
return f"{role}\n{content}", tokenizer.encode(
role, allowed_special=set(tokenizer.IMAGE_ST)
) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))
system_text, system_tokens_part = _tokenize_str("system", system)
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
raw_text = ""
context_tokens = []
for turn_query, turn_response in reversed(history):
query_text, query_tokens_part = _tokenize_str("user", turn_query)
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
if turn_response is not None:
response_text, response_tokens_part = _tokenize_str(
"assistant", turn_response
)
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
prev_chat = (
f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
)
else:
next_context_tokens = nl_tokens + query_tokens + nl_tokens
prev_chat = f"\n{im_start}{query_text}{im_end}\n"
current_context_size = (
len(system_tokens) + len(next_context_tokens) + len(context_tokens)
)
if current_context_size < max_window_size:
context_tokens = next_context_tokens + context_tokens
raw_text = prev_chat + raw_text
else:
break
context_tokens = system_tokens + context_tokens
raw_text = f"{im_start}{system_text}{im_end}" + raw_text
context_tokens += (
nl_tokens
+ im_start_tokens
+ _tokenize_str("user", query)[1]
+ im_end_tokens
+ nl_tokens
+ im_start_tokens
+ tokenizer.encode("assistant")
+ nl_tokens
)
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
elif chat_format == "raw":
raw_text = query
context_tokens = tokenizer.encode(raw_text)
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
return raw_text, context_tokens
def _decode_default(
tokens: List[int],
*,
stop_words: List[str],
eod_words: List[str],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace',
):
trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
if verbose:
print("\nRaw Generate: ", trim_decode_tokens)
end_reason = f"Gen length {len(tokens)}"
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
for eod_word in eod_words:
if eod_word in trim_decode_tokens:
end_reason = f"Gen {eod_word!r}"
trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print("\nEnd Reason:", end_reason)
print("\nGenerate: ", trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def _decode_chatml(
tokens: List[int],
*,
stop_words: List[str],
eod_token_ids: List[int],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace'
):
end_reason = f"Gen length {len(tokens)}"
eod_token_idx = context_length
for eod_token_idx in range(context_length, len(tokens)):
if tokens[eod_token_idx] in eod_token_ids:
end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
break
trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
if verbose:
print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
print("\nRaw Generate:", trim_decode_tokens)
print("\nEnd Reason:", end_reason)
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print("\nGenerate:", trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def decode_tokens(
tokens: Union[torch.LongTensor, TokensType],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
chat_format: str,
verbose: bool = False,
return_end_reason: bool = False,
errors: str="replace",
) -> str:
if torch.is_tensor(tokens):
tokens = tokens.cpu().numpy().tolist()
if chat_format == "chatml":
return _decode_chatml(
tokens,
stop_words=[],
eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
context_length=context_length,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
elif chat_format == "raw":
return _decode_default(
tokens,
stop_words=["<|endoftext|>"],
eod_words=["<|endoftext|>"],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
class StopWordsLogitsProcessor(LogitsProcessor):
"""
:class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration.
Args:
stop_words_ids (:obj:`List[List[int]]`):
List of list of token ids of stop ids. In order to get the tokens of the words
that should not appear in the generated text, use :obj:`tokenizer(bad_word,
add_prefix_space=True).input_ids`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
raise ValueError(
f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
)
if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
raise ValueError(
f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
)
if any(
any(
(not isinstance(token_id, (int, np.integer)) or token_id < 0)
for token_id in stop_word_ids
)
for stop_word_ids in stop_words_ids
):
raise ValueError(
f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
)
self.stop_words_ids = list(
filter(
lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
)
)
self.eos_token_id = eos_token_id
for stop_token_seq in self.stop_words_ids:
assert (
len(stop_token_seq) > 0
), "Stop words token sequences {} cannot have an empty list".format(
stop_words_ids
)
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
stopped_samples = self._calc_stopped_samples(input_ids)
for i, should_stop in enumerate(stopped_samples):
if should_stop:
scores[i, self.eos_token_id] = float(2**15)
return scores
def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
elif len(tokens) > len(prev_tokens):
# if bad word tokens are longer then prev input_ids they can't be equal
return False
elif prev_tokens[-len(tokens) :].tolist() == tokens:
# if tokens match
return True
else:
return False
def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
stopped_samples = []
for prev_input_ids_slice in prev_input_ids:
match = False
for stop_token_seq in self.stop_words_ids:
if self._tokens_match(prev_input_ids_slice, stop_token_seq):
# if tokens do not match continue
match = True
break
stopped_samples.append(match)
return stopped_samples
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
"""This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313"""
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
# Cconvert to 1D
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Tokenization classes for QWen."""
import base64
import logging
import os
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union
import tiktoken
from transformers import PreTrainedTokenizer, AddedToken
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# as the default behavior is changed to allow special tokens in
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
SPECIAL_TOKENS = (
ENDOFTEXT,
IMSTART,
IMEND,
) + EXTRAS
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
with open(tiktoken_bpe_file, "rb") as f:
contents = f.read()
return {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
class QWenTokenizer(PreTrainedTokenizer):
"""QWen tokenizer."""
vocab_files_names = VOCAB_FILES_NAMES
def __init__(
self,
vocab_file,
errors="replace",
**kwargs,
):
super().__init__(**kwargs)
self.errors = errors # how to handle errors in decoding
self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
self.special_tokens = {
token: index
for index, token in enumerate(
SPECIAL_TOKENS, start=len(self.mergeable_ranks)
)
}
enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
mergeable_ranks=self.mergeable_ranks,
special_tokens=self.special_tokens,
)
assert (
len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
self.decoder = {
v: k for k, v in self.mergeable_ranks.items()
} # type: dict[int, bytes|str]
self.decoder.update({v: k for k, v in self.special_tokens.items()})
self.tokenizer = enc # type: tiktoken.Encoding
self.eod_id = self.tokenizer.eot_token
self.im_start_id = self.special_tokens[IMSTART]
self.im_end_id = self.special_tokens[IMEND]
def __len__(self) -> int:
return self.tokenizer.n_vocab
def get_vocab(self) -> Dict[bytes, int]:
return self.mergeable_ranks
def convert_tokens_to_ids(
self, tokens: Union[bytes, str, List[Union[bytes, str]]]
) -> List[int]:
ids = []
if isinstance(tokens, (str, bytes)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.mergeable_ranks.get(tokens)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.mergeable_ranks.get(token))
return ids
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
if not special_tokens and new_tokens:
raise ValueError('Adding regular tokens is not supported')
for token in new_tokens:
surface_form = token.content if isinstance(token, AddedToken) else token
if surface_form not in SPECIAL_TOKENS:
raise ValueError('Adding unknown special tokens is not supported')
return 0
def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
"""
Save only the vocabulary of the tokenizer (vocabulary).
Returns:
`Tuple(str)`: Paths to the files saved.
"""
file_path = os.path.join(save_directory, "qwen.tiktoken")
with open(file_path, "w", encoding="utf8") as w:
for k, v in self.mergeable_ranks.items():
line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
w.write(line)
return (file_path,)
def tokenize(
self,
text: str,
allowed_special: Union[Set, str] = "all",
disallowed_special: Union[Collection, str] = (),
**kwargs,
) -> List[Union[bytes, str]]:
"""
Converts a string in a sequence of tokens.
Args:
text (`str`):
The sequence to be encoded.
allowed_special (`Literal["all"]` or `set`):
The surface forms of the tokens to be encoded as special tokens in regular texts.
Default to "all".
disallowed_special (`Literal["all"]` or `Collection`):
The surface forms of the tokens that should not be in regular texts and trigger errors.
Default to an empty tuple.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific encode method.
Returns:
`List[bytes|str]`: The list of tokens.
"""
tokens = []
text = unicodedata.normalize("NFC", text)
# this implementation takes a detour: text -> token id -> token surface forms
for t in self.tokenizer.encode(
text, allowed_special=allowed_special, disallowed_special=disallowed_special
):
tokens.append(self.decoder[t])
return tokens
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
"""
Converts a sequence of tokens in a single string.
"""
text = ""
temp = b""
for t in tokens:
if isinstance(t, str):
if temp:
text += temp.decode("utf-8", errors=self.errors)
temp = b""
text += t
elif isinstance(t, bytes):
temp += t
else:
raise TypeError("token should only be of type types or str")
if temp:
text += temp.decode("utf-8", errors=self.errors)
return text
@property
def vocab_size(self):
return self.tokenizer.n_vocab
def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
"""Converts an id to a token, special tokens included"""
if index in self.decoder:
return self.decoder[index]
raise ValueError("unknown ids")
def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
"""Converts a token to an id using the vocab, special tokens included"""
if token in self.special_tokens:
return self.special_tokens[token]
if token in self.mergeable_ranks:
return self.mergeable_ranks[token]
raise ValueError("unknown token")
def _tokenize(self, text: str, **kwargs):
"""
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
Do NOT take care of added tokens.
"""
raise NotImplementedError
def _decode(
self,
token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
errors: str = None,
**kwargs,
) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
if skip_special_tokens:
token_ids = [i for i in token_ids if i < self.eod_id]
return self.tokenizer.decode(token_ids, errors=errors or self.errors)
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import cv2
import numpy as np
import torch
# from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
class BaseProcessor:
def __init__(self):
self.transform = lambda x: x
return
def __call__(self, item):
return self.transform(item)
# @classmethod
# def from_config(cls, cfg=None):
# return cls()
# def build(self, **kwargs):
# cfg = OmegaConf.create(kwargs)
# return self.from_config(cfg)
class BlipImageBaseProcessor(BaseProcessor):
def __init__(self, mean=None, std=None):
if mean is None:
mean = (0.48145466, 0.4578275, 0.40821073)
if std is None:
std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = transforms.Normalize(mean, std)
## aug functions
def identity_func(img):
return img
def autocontrast_func(img, cutoff=0):
"""
same output as PIL.ImageOps.autocontrast
"""
n_bins = 256
def tune_channel(ch):
n = ch.size
cut = cutoff * n // 100
if cut == 0:
high, low = ch.max(), ch.min()
else:
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
low = np.argwhere(np.cumsum(hist) > cut)
low = 0 if low.shape[0] == 0 else low[0]
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
if high <= low:
table = np.arange(n_bins)
else:
scale = (n_bins - 1) / (high - low)
offset = -low * scale
table = np.arange(n_bins) * scale + offset
table[table < 0] = 0
table[table > n_bins - 1] = n_bins - 1
table = table.clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def equalize_func(img):
"""
same output as PIL.ImageOps.equalize
PIL's implementation is different from cv2.equalize
"""
n_bins = 256
def tune_channel(ch):
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
non_zero_hist = hist[hist != 0].reshape(-1)
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
if step == 0:
return ch
n = np.empty_like(hist)
n[0] = step // 2
n[1:] = hist[:-1]
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def rotate_func(img, degree, fill=(0, 0, 0)):
"""
like PIL, rotate by degree, not radians
"""
H, W = img.shape[0], img.shape[1]
center = W / 2, H / 2
M = cv2.getRotationMatrix2D(center, degree, 1)
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
return out
def solarize_func(img, thresh=128):
"""
same output as PIL.ImageOps.posterize
"""
table = np.array([el if el < thresh else 255 - el for el in range(256)])
table = table.clip(0, 255).astype(np.uint8)
out = table[img]
return out
def color_func(img, factor):
"""
same output as PIL.ImageEnhance.Color
"""
## implementation according to PIL definition, quite slow
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
# out = blend(degenerate, img, factor)
# M = (
# np.eye(3) * factor
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
# )[np.newaxis, np.newaxis, :]
M = np.float32(
[[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
) * factor + np.float32([[0.114], [0.587], [0.299]])
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
return out
def contrast_func(img, factor):
"""
same output as PIL.ImageEnhance.Contrast
"""
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
table = (
np.array([(el - mean) * factor + mean for el in range(256)])
.clip(0, 255)
.astype(np.uint8)
)
out = table[img]
return out
def brightness_func(img, factor):
"""
same output as PIL.ImageEnhance.Contrast
"""
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def sharpness_func(img, factor):
"""
The differences the this result and PIL are all on the 4 boundaries, the center
areas are same
"""
kernel = np.ones((3, 3), dtype=np.float32)
kernel[1][1] = 5
kernel /= 13
degenerate = cv2.filter2D(img, -1, kernel)
if factor == 0.0:
out = degenerate
elif factor == 1.0:
out = img
else:
out = img.astype(np.float32)
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
out = out.astype(np.uint8)
return out
def shear_x_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, factor, 0], [0, 1, 0]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
).astype(np.uint8)
return out
def translate_x_func(img, offset, fill=(0, 0, 0)):
"""
same output as PIL.Image.transform
"""
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, -offset], [0, 1, 0]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
).astype(np.uint8)
return out
def translate_y_func(img, offset, fill=(0, 0, 0)):
"""
same output as PIL.Image.transform
"""
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [0, 1, -offset]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
).astype(np.uint8)
return out
def posterize_func(img, bits):
"""
same output as PIL.ImageOps.posterize
"""
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
return out
def shear_y_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [factor, 1, 0]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
).astype(np.uint8)
return out
def cutout_func(img, pad_size, replace=(0, 0, 0)):
replace = np.array(replace, dtype=np.uint8)
H, W = img.shape[0], img.shape[1]
rh, rw = np.random.random(2)
pad_size = pad_size // 2
ch, cw = int(rh * H), int(rw * W)
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
out = img.copy()
out[x1:x2, y1:y2, :] = replace
return out
### level to args
def enhance_level_to_args(MAX_LEVEL):
def level_to_args(level):
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
return level_to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 0.3
if np.random.random() > 0.5:
level = -level
return (level, replace_value)
return level_to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * float(translate_const)
if np.random.random() > 0.5:
level = -level
return (level, replace_value)
return level_to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = int((level / MAX_LEVEL) * cutout_const)
return (level, replace_value)
return level_to_args
def solarize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 256)
return (level,)
return level_to_args
def none_level_to_args(level):
return ()
def posterize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 4)
return (level,)
return level_to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 30
if np.random.random() < 0.5:
level = -level
return (level, replace_value)
return level_to_args
func_dict = {
"Identity": identity_func,
"AutoContrast": autocontrast_func,
"Equalize": equalize_func,
"Rotate": rotate_func,
"Solarize": solarize_func,
"Color": color_func,
"Contrast": contrast_func,
"Brightness": brightness_func,
"Sharpness": sharpness_func,
"ShearX": shear_x_func,
"TranslateX": translate_x_func,
"TranslateY": translate_y_func,
"Posterize": posterize_func,
"ShearY": shear_y_func,
}
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
"Identity": none_level_to_args,
"AutoContrast": none_level_to_args,
"Equalize": none_level_to_args,
"Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
"Solarize": solarize_level_to_args(MAX_LEVEL),
"Color": enhance_level_to_args(MAX_LEVEL),
"Contrast": enhance_level_to_args(MAX_LEVEL),
"Brightness": enhance_level_to_args(MAX_LEVEL),
"Sharpness": enhance_level_to_args(MAX_LEVEL),
"ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
"TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
"TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
"Posterize": posterize_level_to_args(MAX_LEVEL),
"ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
}
class RandomAugment(object):
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
self.N = N
self.M = M
self.isPIL = isPIL
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N)
return [(op, 0.5, self.M) for op in sampled_ops]
def __call__(self, img):
if self.isPIL:
img = np.array(img)
ops = self.get_random_ops()
for name, prob, level in ops:
if np.random.random() > prob:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
return img
class VideoRandomAugment(object):
def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
self.N = N
self.M = M
self.p = p
self.tensor_in_tensor_out = tensor_in_tensor_out
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N, replace=False)
return [(op, self.M) for op in sampled_ops]
def __call__(self, frames):
assert (
frames.shape[-1] == 3
), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
if self.tensor_in_tensor_out:
frames = frames.numpy().astype(np.uint8)
num_frames = frames.shape[0]
ops = num_frames * [self.get_random_ops()]
apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
frames = torch.stack(
list(map(self._aug, frames, ops, apply_or_not)), dim=0
).float()
return frames
def _aug(self, img, ops, apply_or_not):
for i, (name, level) in enumerate(ops):
if not apply_or_not[i]:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
return torch.from_numpy(img)
# if __name__ == "__main__":
# a = RandomAugment()
# img = np.random.randn(32, 32, 3)
# a(img)
class BlipImageTrainProcessor(BlipImageBaseProcessor):
def __init__(
self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
):
super().__init__(mean=mean, std=std)
self.transform = transforms.Compose(
[
transforms.RandomResizedCrop(
image_size,
scale=(min_scale, max_scale),
interpolation=InterpolationMode.BICUBIC,
),
# transforms.RandomHorizontalFlip(),
RandomAugment(
2,
5,
isPIL=True,
augs=[
"Identity",
# "AutoContrast",
"Brightness",
"Sharpness",
"Equalize",
# "ShearX",
# "ShearY",
# "TranslateX",
# "TranslateY",
# "Rotate",
],
),
transforms.ToTensor(),
self.normalize,
]
)
def __call__(self, item):
return self.transform(item)
class BlipImageEvalProcessor(BlipImageBaseProcessor):
def __init__(self, image_size=384, mean=None, std=None):
super().__init__(mean=mean, std=std)
self.transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size), interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
self.normalize,
]
)
def __call__(self, item):
return self.transform(item)
# if __name__ == "__main__":
# a = BlipImageTrainProcessor(image_size=1024)
# # img = np.random.randn(1024, 1024, 3)
# # x = torch.zeros(1024, 1024, 3)
# x = Image.open("/data/codes/vary-main/log/serve_images/2023-05-23/a2a783d89ede819cdeae943a2199ad3d.jpg").convert("RGB")
# print(x.size)
# y = a(x)
# print(y.size())
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
# Implements image augmentation
import albumentations as alb
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def alb_wrapper(transform):
def f(im):
return transform(image=np.asarray(im))["image"]
return f
class Erosion(alb.ImageOnlyTransform):
"""
Apply erosion operation to an image.
Erosion is a morphological operation that shrinks the white regions in a binary image.
Args:
scale (int or tuple/list of int): The scale or range for the size of the erosion kernel.
If an integer is provided, a square kernel of that size will be used.
If a tuple or list is provided, it should contain two integers representing the minimum
and maximum sizes for the erosion kernel.
always_apply (bool, optional): Whether to always apply this transformation. Default is False.
p (float, optional): The probability of applying this transformation. Default is 0.5.
Returns:
numpy.ndarray: The transformed image.
"""
def __init__(self, scale, always_apply=False, p=0.5):
super().__init__(always_apply=always_apply, p=p)
if type(scale) is tuple or type(scale) is list:
assert len(scale) == 2
self.scale = scale
else:
self.scale = (scale, scale)
def apply(self, img, **params):
kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
)
img = cv2.erode(img, kernel, iterations=1)
return img
class Dilation(alb.ImageOnlyTransform):
"""
Apply dilation operation to an image.
Dilation is a morphological operation that expands the white regions in a binary image.
Args:
scale (int or tuple/list of int): The scale or range for the size of the dilation kernel.
If an integer is provided, a square kernel of that size will be used.
If a tuple or list is provided, it should contain two integers representing the minimum
and maximum sizes for the dilation kernel.
always_apply (bool, optional): Whether to always apply this transformation. Default is False.
p (float, optional): The probability of applying this transformation. Default is 0.5.
Returns:
numpy.ndarray: The transformed image.
"""
def __init__(self, scale, always_apply=False, p=0.5):
super().__init__(always_apply=always_apply, p=p)
if type(scale) is tuple or type(scale) is list:
assert len(scale) == 2
self.scale = scale
else:
self.scale = (scale, scale)
def apply(self, img, **params):
kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
)
img = cv2.dilate(img, kernel, iterations=1)
return img
class Bitmap(alb.ImageOnlyTransform):
"""
Apply a bitmap-style transformation to an image.
This transformation replaces all pixel values below a certain threshold with a specified value.
Args:
value (int, optional): The value to replace pixels below the threshold with. Default is 0.
lower (int, optional): The threshold value below which pixels will be replaced. Default is 200.
always_apply (bool, optional): Whether to always apply this transformation. Default is False.
p (float, optional): The probability of applying this transformation. Default is 0.5.
Returns:
numpy.ndarray: The transformed image.
"""
def __init__(self, value=0, lower=200, always_apply=False, p=0.5):
super().__init__(always_apply=always_apply, p=p)
self.lower = lower
self.value = value
def apply(self, img, **params):
img = img.copy()
img[img < self.lower] = self.value
return img
train_transform = alb_wrapper(
alb.Compose(
[
Bitmap(p=0),
alb.OneOf([Erosion((2, 3)), Dilation((2, 3))], p=0.02),
alb.Affine(shear={"x": (0, 3), "y": (-3, 0)}, cval=(255, 255, 255), p=0.03),
alb.ShiftScaleRotate(
shift_limit_x=(0, 0.04),
shift_limit_y=(0, 0.03),
scale_limit=(-0.15, 0.03),
rotate_limit=2,
border_mode=0,
interpolation=2,
value=(255, 255, 255),
p=0.03,
),
alb.GridDistortion(
distort_limit=0.05,
border_mode=0,
interpolation=2,
value=(255, 255, 255),
p=0.04,
),
alb.Compose(
[
alb.Affine(
translate_px=(0, 5), always_apply=True, cval=(255, 255, 255)
),
alb.ElasticTransform(
p=1,
alpha=50,
sigma=120 * 0.1,
alpha_affine=120 * 0.01,
border_mode=0,
value=(255, 255, 255),
),
],
p=0.04,
),
alb.RandomBrightnessContrast(0.1, 0.1, True, p=0.03),
alb.ImageCompression(95, p=0.07),
alb.GaussNoise(20, p=0.08),
alb.GaussianBlur((3, 3), p=0.03),
alb.Resize(1024, 1024),
alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
ToTensorV2(),
]
)
)
test_transform = alb_wrapper(
alb.Compose(
[
alb.Resize(1024, 1024),
alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
ToTensorV2(),
]
)
)
# if __name__ == '__main__':
# from PIL import Image
# image = Image.open('/data/hypertext/ucaswei/codes_new/show/49.jpg').convert('RGB')
# # image = np.array(image)
# for i in range(100):
# image1 = train_transform(image)
# image1 = Image.fromarray(np.uint8(image1))
# image1.save('/data/hypertext/ucaswei/codes_new/aug/' + str(i) + '.jpg')
# mm_projector_1 = nn.Linear(1024, 256)
# x = torch.zeros(2, 3, 1024, 1024)
# with torch.no_grad():
# y = model(x)
# print(y.shape)
# y = mm_projector_1(y.permute(0,2,1))
# print(y.shape)
# print(y.permute(0,2,1).shape)
\ No newline at end of file
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, \
LlamaConfig, LlamaModel, LlamaForCausalLM, \
CLIPVisionModel, CLIPImageProcessor
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from vary.utils.constants import *
from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.vision_encoder.sam import build_sam_vit_b
from transformers import OPTConfig, OPTModel, OPTForCausalLM
from vary.model.plug.transforms import train_transform, test_transform
class varyConfig(OPTConfig):
model_type = "vary"
class varyOPTModel(OPTModel):
config_class = varyConfig
def __init__(self, config: OPTConfig):
super(varyOPTModel, self).__init__(config)
self.vision_tower = build_sam_vit_b()
self.mm_projector = nn.Linear(1024, 768)
def initialize_vision_modules(
self,
vision_tower,
pretrained_stage1_model=None,
freeze_vision_tower=False,
use_im_start_end=False,
vision_select_layer=-1,
dtype=torch.float16,
device="cuda"
):
# 224*224
# image_processor do not used in opt
image_processor = CLIPImageProcessor.from_pretrained('/cache/vit-large-patch14')
# 1024*1024
image_processor_high = train_transform
image_token_len = 256
self.config.vision_tower = vision_tower
self.config.image_token_len = image_token_len
self.config.use_im_start_end = True
self.config.vision_select_layer = vision_select_layer
self.config.freeze_vision_tower = freeze_vision_tower
return dict(
image_processor=image_processor,
image_processor_high=image_processor_high,
image_token_len=image_token_len,
)
def embed_tokens(self, x):
return self.get_input_embeddings()(x)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# HACK: replace back original embeddings for LLaVA pretraining
# orig_embeds_params = getattr(self, 'orig_embeds_params', None)
# if orig_embeds_params is not None:
# with torch.no_grad():
# self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# inputs_embeds = self.wte(input_ids)
vision_tower = getattr(self, 'vision_tower', None)
if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
use_im_start_end = getattr(self.config, "use_im_start_end", -1)
vision_select_layer = getattr(self.config, "vision_select_layer", -1)
im_patch_token = getattr(self.config, "im_patch_token", -1)
im_start_token = getattr(self.config, "im_start_token", -1)
im_end_token = getattr(self.config, "im_end_token", -1)
freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
image_features = []
for image in images:
with torch.set_grad_enabled(True):
cnn_feature = vision_tower(image[1])
cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1)
image_feature_final = cnn_feature
image_features.append(image_feature_final)
if type(images) is list:
image_features = [self.mm_projector(image_feature) for image_feature in image_features]
else:
# image_features = self.mm_projector(image_features)
raise NotImplementedError
# dummy_image_features = torch.zeros(1024, 1664, device=inputs_embeds.device, dtype=inputs_embeds.dtype).permute(0, 2, 1).reshape(dummy_image_features.shape[0], -1, 32, 32)
# VIT 1024; CNN:1024
dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features = self.mm_projector(dummy_image_features)
use_im_start_end = True
new_input_embeds = []
for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
if (cur_input_ids == im_patch_token).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
continue
if use_im_start_end:
if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
raise ValueError("The number of image start tokens and image end tokens should be the same.")
image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
num_patches = per_cur_image_features.shape[0]
# print(cur_input_ids)
if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
raise ValueError("The image end token should follow the image start token.")
cur_input_embeds = torch.cat(
(
cur_input_embeds[:image_start_token_pos+1],
per_cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches + 1:]
),
dim=0
)
new_input_embeds.append(cur_input_embeds)
else:
raise NotImplementedError
inputs_embeds = torch.stack(new_input_embeds, dim=0)
return super(varyOPTModel, self).forward(
input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, use_cache=use_cache,
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
return_dict=return_dict
)
class varyOPTForCausalLM(OPTForCausalLM):
config_class = varyConfig
# supports_gradient_checkpointing = True
def __init__(self, config):
super(OPTForCausalLM, self).__init__(config)
self.model = varyOPTModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
# def _set_gradient_checkpointing(self, module, value=False):
# if isinstance(module, varyQwenModel):
# module.gradient_checkpointing = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
# print(input_ids)
# print(len(images))
outputs = self.model(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
images=images,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states).contiguous()
# logits
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None)
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"images": kwargs.get("images", None),
}
)
return model_inputs
def initialize_vision_tokenizer(
self,
tokenizer,
freeze_lm_model=False,
pretrained_stage1_model=None,
device="cuda"
):
config = self.get_model().config
# add image patch token <image>
tokenizer.add_tokens("</s>", special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
tokenizer.add_tokens(DEFAULT_IMAGE_PATCH_TOKEN, special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
config.im_patch_token = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_PATCH_TOKEN)
config.use_im_start_end = True
# add image start token <im_start> and end token <im_end>
if config.use_im_start_end:
num_new_tokens = 2
tokenizer.add_tokens(DEFAULT_IM_START_TOKEN , special_tokens=True)
tokenizer.add_tokens(DEFAULT_IM_END_TOKEN , special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
config.im_start_token = tokenizer.convert_tokens_to_ids(DEFAULT_IM_START_TOKEN)
config.im_end_token = tokenizer.convert_tokens_to_ids(DEFAULT_IM_END_TOKEN)
# config.im_start_token, config.im_end_token = 151857, 151858
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
AutoConfig.register("vary", varyConfig)
AutoModelForCausalLM.register(varyConfig, varyOPTForCausalLM)
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, \
CLIPVisionModel, CLIPImageProcessor
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from vary.utils.constants import *
from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.llm.qwen.modeling_qwen import QWenLMHeadModel, QWenModel
from vary.model.llm.qwen.configuration_qwen import QWenConfig
from vary.model.vision_encoder.sam import build_sam_vit_b
from vary.model.plug.transforms import train_transform, test_transform
class varyConfig(QWenConfig):
model_type = "vary"
class varyQwenModel(QWenModel):
config_class = varyConfig
def __init__(self, config: QWenConfig):
super(varyQwenModel, self).__init__(config)
# TODO download the clip-vit in huggingface
self.vision_tower = CLIPVisionModel.from_pretrained('/cache/vit-large-patch14/')
self.vision_tower_high = build_sam_vit_b() # build_sam_vit_b(checkpoint = 'xxxx') for train
self.mm_projector = nn.Linear(1024, 2048)
self.mm_projector_vary = nn.Linear(1024, 2048)
def initialize_vision_modules(
self,
vision_tower,
pretrained_stage1_model=None,
freeze_vision_tower=False,
use_im_start_end=False,
vision_select_layer=-1,
dtype=torch.float16,
device="cuda"
):
# 224*224
# TODO download the clip-vit in huggingface
image_processor = CLIPImageProcessor.from_pretrained('/cache/vit-large-patch14/')
# 1024*1024
image_processor_high = train_transform
self.vision_tower = self.vision_tower.to(dtype=dtype, device=device)
self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
self.mm_projector = self.mm_projector.to(dtype=dtype, device=device)
self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
image_token_len = 256
self.config.vision_tower = vision_tower
self.config.image_token_len = image_token_len
self.config.use_im_start_end = True
self.config.vision_select_layer = vision_select_layer
self.config.freeze_vision_tower = freeze_vision_tower
return dict(
image_processor=image_processor,
image_processor_high=image_processor_high,
image_token_len=image_token_len,
)
def embed_tokens(self, x):
return self.wte(x)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# HACK: replace back original embeddings for LLaVA pretraining
# orig_embeds_params = getattr(self, 'orig_embeds_params', None)
# if orig_embeds_params is not None:
# with torch.no_grad():
# self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# inputs_embeds = self.wte(input_ids)
vision_tower = getattr(self, 'vision_tower', None)
vision_tower_high = getattr(self, 'vision_tower_high', None)
if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
use_im_start_end = getattr(self.config, "use_im_start_end", -1)
vision_select_layer = getattr(self.config, "vision_select_layer", -1)
# im_patch_token = getattr(self.config, "im_patch_token", -1)
# im_start_token = getattr(self.config, "im_start_token", -1)
# im_end_token = getattr(self.config, "im_end_token", -1)
# freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
im_patch_token = 151859
im_start_token = 151857
im_end_token = 151858
image_features_1 = []
image_features_2 = []
for image in images:
with torch.set_grad_enabled(False):
image_forward_out = vision_tower(image[0], output_hidden_states=True)
select_hidden_state = image_forward_out.hidden_states[vision_select_layer]
image_feature = select_hidden_state[:, 1:] # 256*1024
with torch.set_grad_enabled(False):
cnn_feature = vision_tower_high(image[1])
cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
image_features_1.append(image_feature)
image_features_2.append(cnn_feature)
if type(images) is list:
image_features_1 = [self.mm_projector(image_feature) for image_feature in image_features_1]
image_features_2 = [self.mm_projector_vary(image_feature) for image_feature in image_features_2]
image_features = [torch.cat((image_feature[0], image_feature[1]), dim=-1) for image_feature in zip(image_features_1, image_features_2)]
else:
raise NotImplementedError
# dummy_image_features = torch.zeros(256, 4096, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features_1 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features_1 = self.mm_projector(dummy_image_features_1)
dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2)
dummy_image_features = torch.cat((dummy_image_features_1, dummy_image_features_2), dim=-1)
use_im_start_end = True
new_input_embeds = []
for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
if (cur_input_ids == im_patch_token).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
continue
if use_im_start_end:
if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
raise ValueError("The number of image start tokens and image end tokens should be the same.")
image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
num_patches = per_cur_image_features.shape[0]
if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
raise ValueError("The image end token should follow the image start token.")
# if orig_embeds_params is not None:
# cur_new_input_embeds = torch.cat(
# (
# cur_input_embeds[:image_start_token_pos].detach(),
# cur_input_embeds[image_start_token_pos:image_start_token_pos+1],
# per_cur_image_features,
# cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2],
# cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()
# ),
# dim=0
# )
# else:
cur_input_embeds = torch.cat(
(
cur_input_embeds[:image_start_token_pos+1],
per_cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches + 1:]
),
dim=0
)
new_input_embeds.append(cur_input_embeds)
else:
raise NotImplementedError
inputs_embeds = torch.stack(new_input_embeds, dim=0)
return super(varyQwenModel, self).forward(
input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, use_cache=use_cache,
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
return_dict=return_dict
)
class varyQwenForCausalLM(QWenLMHeadModel):
config_class = varyConfig
# supports_gradient_checkpointing = True
def __init__(self, config):
super(QWenLMHeadModel, self).__init__(config)
self.transformer = varyQwenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.transformer
# def _set_gradient_checkpointing(self, module, value=False):
# if isinstance(module, varyQwenModel):
# module.gradient_checkpointing = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
transformer_outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
images=images,
return_dict=return_dict
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
# logits
loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
# print(loss)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None)
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"images": kwargs.get("images", None),
}
)
return model_inputs
def initialize_vision_tokenizer(
self,
tokenizer,
freeze_lm_model=False,
pretrained_stage1_model=None,
device="cuda"
):
config = self.get_model().config
self.resize_token_embeddings(len(tokenizer))
config.im_patch_token = 151859
config.use_im_start_end = True
# add image start token <im_start> and end token <im_end>
if config.use_im_start_end:
# num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
# config.im_start_token, config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
config.im_start_token, config.im_end_token = 151857, 151858
AutoConfig.register("vary", varyConfig)
AutoModelForCausalLM.register(varyConfig, varyQwenForCausalLM)
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, \
CLIPVisionModel, CLIPImageProcessor
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from vary.utils.constants import *
from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.llm.qwen.modeling_qwen import QWenLMHeadModel, QWenModel
from vary.model.llm.qwen.configuration_qwen import QWenConfig
from vary.model.vision_encoder.sam import build_sam_vit_b
class varyConfig(QWenConfig):
model_type = "vary"
class varyQwenModel(QWenModel):
config_class = varyConfig
def __init__(self, config: QWenConfig):
super(varyQwenModel, self).__init__(config)
self.vision_tower = CLIPVisionModel.from_pretrained('/data/hypertext/ucaswei/cache/vit-large-patch14/vit-large-patch14/')
self.vision_tower_high = build_sam_vit_b()
self.mm_projector = nn.Linear(1024, 1024)
self.mm_projector_vary = nn.Linear(1024, 1024)
def initialize_vision_modules(
self,
vision_tower,
pretrained_stage1_model=None,
freeze_vision_tower=False,
use_im_start_end=False,
vision_select_layer=-1,
dtype=torch.float16,
device="cuda"
):
# 224*224
image_processor = CLIPImageProcessor.from_pretrained('/data/hypertext/ucaswei/cache/vit-large-patch14/vit-large-patch14/')
# 1024*1024
image_processor_high = BlipImageEvalProcessor(image_size=1024)
self.vision_tower = self.vision_tower.to(dtype=dtype, device=device)
self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
self.mm_projector = self.mm_projector.to(dtype=dtype, device=device)
self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
image_token_len = 256
self.config.vision_tower = vision_tower
self.config.image_token_len = image_token_len
self.config.use_im_start_end = True
self.config.vision_select_layer = vision_select_layer
self.config.freeze_vision_tower = freeze_vision_tower
return dict(
image_processor=image_processor,
image_processor_high=image_processor_high,
image_token_len=image_token_len,
# vision_config=vision_config
)
def embed_tokens(self, x):
return self.wte(x)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
vision_tower = getattr(self, 'vision_tower', None)
vision_tower_high = getattr(self, 'vision_tower_high', None)
if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
use_im_start_end = getattr(self.config, "use_im_start_end", -1)
vision_select_layer = getattr(self.config, "vision_select_layer", -1)
im_patch_token = getattr(self.config, "im_patch_token", -1)
im_start_token = getattr(self.config, "im_start_token", -1)
im_end_token = getattr(self.config, "im_end_token", -1)
freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
im_patch_token = 151859
im_start_token = 151857
im_end_token = 151858
image_features = []
image_features_1 = []
image_features_2 = []
for image in images:
with torch.set_grad_enabled(False):
image_forward_out = vision_tower(image[0], output_hidden_states=True)
select_hidden_state = image_forward_out.hidden_states[vision_select_layer]
image_feature = select_hidden_state[:, 1:] # 256*1024
with torch.set_grad_enabled(False):
cnn_feature = vision_tower_high(image[1])
cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
image_features_1.append(image_feature)
image_features_2.append(cnn_feature)
if type(images) is list:
image_features_1 = [self.mm_projector(image_feature) for image_feature in image_features_1]
image_features_2 = [self.mm_projector_vary(image_feature) for image_feature in image_features_2]
image_features = [torch.cat((image_feature[0], image_feature[1]), dim=-1) for image_feature in zip(image_features_1, image_features_2)]
else:
raise NotImplementedError
dummy_image_features_1 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features_1 = self.mm_projector(dummy_image_features_1)
dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2)
dummy_image_features = torch.cat((dummy_image_features_1, dummy_image_features_2), dim=-1)
use_im_start_end = True
new_input_embeds = []
for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
if (cur_input_ids == im_patch_token).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
continue
if use_im_start_end:
if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
raise ValueError("The number of image start tokens and image end tokens should be the same.")
image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
num_patches = per_cur_image_features.shape[0]
if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
raise ValueError("The image end token should follow the image start token.")
cur_input_embeds = torch.cat(
(
cur_input_embeds[:image_start_token_pos+1],
per_cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches + 1:]
),
dim=0
)
new_input_embeds.append(cur_input_embeds)
else:
raise NotImplementedError
inputs_embeds = torch.stack(new_input_embeds, dim=0)
return super(varyQwenModel, self).forward(
input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, use_cache=use_cache,
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
return_dict=return_dict
)
class varyQwenForCausalLM(QWenLMHeadModel):
config_class = varyConfig
# supports_gradient_checkpointing = True
def __init__(self, config):
super(QWenLMHeadModel, self).__init__(config)
self.transformer = varyQwenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.transformer
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
images=images,
return_dict=return_dict
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
# logits
loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None)
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"images": kwargs.get("images", None),
}
)
return model_inputs
def initialize_vision_tokenizer(
self,
tokenizer,
freeze_lm_model=False,
pretrained_stage1_model=None,
device="cuda"
):
config = self.get_model().config
self.resize_token_embeddings(len(tokenizer))
config.im_patch_token = 151859
config.use_im_start_end = True
if config.use_im_start_end:
self.resize_token_embeddings(len(tokenizer))
config.im_start_token, config.im_end_token = 151857, 151858
AutoConfig.register("vary", varyConfig)
AutoModelForCausalLM.register(varyConfig, varyQwenForCausalLM)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from functools import partial
import torch
import torch.nn as nn
from typing import Type
import math
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2))
x = self.net_2(x)
x = self.net_3(x)
return x
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
def build_sam_vit_b(checkpoint=None):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)
if checkpoint is not None:
# with open(checkpoint, "rb") as f:
state_dict = torch.load(checkpoint)
image_encoder.load_state_dict(state_dict, strict=True)
# image_encoder.load_state_dict({k[19:]: v for k, v in state_dict.items() if 'vision_tower' in k}, strict=True)
print(checkpoint)
return image_encoder
if __name__ == '__main__':
x = torch.zeros(2, 3, 1024, 1024)
# x.permute(0, 3, 1, 2)
net = build_sam_vit_b(checkpoint ='')
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
from vary.train.train import train
if __name__ == "__main__":
train()
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
# from vary.train.train import train
from vary.train.train_lora import train
if __name__ == "__main__":
train()
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pathlib
import torch
import transformers
from vary.train.trainer_vit_fixlr import varyTrainer
from vary.model import *
from vary.data import make_supervised_data_module
from vary.utils.arguments import *
from vary.utils.constants import *
from vary.model.vision_encoder.sam import build_sam_vit_b
def train():
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, use_fast=False, padding_side="right", model_max_length=training_args.model_max_length)
model = varyOPTForCausalLM.from_pretrained(model_args.model_name_or_path)
dtype = torch.float32
if training_args.fp16:
dtype = torch.float16
if training_args.bf16:
dtype = torch.bfloat16
vision_tower_dict = model.get_model().initialize_vision_modules(
vision_tower=model_args.vision_tower,
pretrained_stage1_model=model_args.pretrained_stage1_model,
freeze_vision_tower=model_args.freeze_vision_tower,
use_im_start_end=model_args.use_im_start_end,
vision_select_layer=model_args.vision_select_layer,
dtype=dtype,
device=training_args.device
)
model.initialize_vision_tokenizer(
tokenizer=tokenizer,
freeze_lm_model=model_args.freeze_lm_model,
pretrained_stage1_model=model_args.pretrained_stage1_model,
device=training_args.device,
)
model.to(dtype=dtype, device=training_args.device)
data_args.image_token_len = 256
data_args.image_processor = vision_tower_dict['image_processor']
data_args.image_processor_high = vision_tower_dict['image_processor_high']
data_args.use_im_start_end = model_args.use_im_start_end
# mixed relation, to be fixed
if model_args.freeze_lm_model:
model.requires_grad_(False)
for p in model.get_model().mm_projector.parameters():
p.requires_grad = True
for p in model.get_input_embeddings().parameters():
p.requires_grad = True
if not model_args.freeze_vision_tower:
model.get_model().vision_tower.requires_grad_(True)
params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]
print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M")
# params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
# if len(params_no_grad) > 0:
# if training_args.fsdp is not None and len(training_args.fsdp) > 0:
# if len(params_no_grad) < 10:
# print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))
# else:
# print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))
# print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.")
# print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
# from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
# def patch_FSDP_use_orig_params(func):
# def wrap_func(*args, **kwargs):
# use_orig_params = kwargs.pop('use_orig_params', True)
# return func(*args, **kwargs, use_orig_params=use_orig_params)
# return wrap_func
# FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
# interleave = True
data_module = make_supervised_data_module(
interleave=training_args.interleave,
with_box=training_args.with_box,
tokenizer=tokenizer,
data_args=data_args
)
trainer = varyTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
**data_module)
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
trainer._safe_save(output_dir=training_args.output_dir)
if __name__ == "__main__":
train()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment