Commit c873301f authored by wanglch

Initial commit

import io
import os
import copy
import json
import logging
import torch
import random
from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from vary.data.base_dataset import BaseDataset
from vary.utils.constants import *
from vary.utils import conversation as conversation_lib
class ConversationDataset(BaseDataset):
"""Conversation format dataset stage2 fine-tuning."""
def __init__(self, datasets, tokenizer, multimodal_cfg):
super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
        # Use the "mpt" conversation template for stage-2 conversation data
        conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
        logging.warning("Formatting inputs into conversation type: mpt")
logging.warning("Loading data...")
list_data_dict = []
list_image_path = []
for name in datasets.split("+"):
# for name in vary_data_dict[name_all]:
dataset = CONVERSATION_DATA[name]
data_path = dataset['annotations']
data = json.load(open(data_path, "r"))
list_data_dict.extend(data)
image_path = dataset['images']
list_image_path.extend([image_path] * len(data))
            logging.warning(f"Data from {data_path} provides {len(data)} conversations.")
assert len(list_data_dict) == len(list_image_path)
logging.warning(f"{len(list_data_dict)} conversations in total.")
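        # Shuffle the (annotation, image_root) pairs together so each record keeps its image directory.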
a_new_list = list(zip(list_data_dict, list_image_path))
random.shuffle(a_new_list)
list_data_dict_new, list_image_path_new = zip(*a_new_list)
self.list_data_dict = list_data_dict_new
self.list_image_path = list_image_path_new
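        # Hard-coded token ids for the image patch / start / end special tokens;
        # these presumably correspond to the extra tokens of the extended Qwen tokenizer used here.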
self.im_patch_token = 151859
self.im_start_token = 151857
self.im_end_token = 151858
def multimodal_processor(self, sources):
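        # Replace every <image> placeholder with <img> + image_token_len * <imgpad> + </img>,
        # marking where the projected vision features will later be inserted.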
for source in sources:
if self.multimodal_cfg['sep_image_conv_front']:
assert DEFAULT_IMAGE_TOKEN in source[0]['value']
source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
for sentence in source:
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
# if self.multimodal_cfg['use_im_start_end']:
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token)
return sources
def _tokenize_fn(self, strings):
"""Tokenize a list of strings."""
tokenized_list = [
self.tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=self.tokenizer.model_max_length,
truncation=True,
) for text in strings
]
input_ids = labels = [
tokenized.input_ids[0] for tokenized in tokenized_list
]
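        # Note: input_ids and labels alias the same tensors here; masking of
        # non-assistant tokens is applied separately (see _mask_targets).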
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(self, target, tokenized_lens, speakers):
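        # Ignore the leading header segment and every human turn so that only
        # the assistant replies contribute to the loss.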
# cur_idx = 0
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:]
target[:cur_idx] = IGNORE_INDEX
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker.lower() == "human":
target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
cur_idx += tokenized_len
def token_processor(self, sources):
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
input_ids = self.tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=self.tokenizer.model_max_length,
truncation=True,
).input_ids
# input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
# Mask targets
sep = conv.sep + conv.roles[1]
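        # Each round is split at "<sep><assistant role>": everything up to and including the
        # assistant tag is masked, so only the reply tokens keep their labels.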
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(self.tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep)
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
for conv_idx in range(3, len(rounds), 2):
re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
cur_len = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids)
instruction_len = len(self.tokenizer(parts[0]).input_ids)
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < self.tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# data = self.list_data_dict[i]
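        # Deep copy so the in-place <image> replacement in multimodal_processor does not mutate the cached annotation.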
data = copy.deepcopy(self.list_data_dict[i])
if isinstance(data, dict):
if 'image' in data:
image_path = self.list_image_path[i]
image_file = data['image']
try:
image = Image.open(image_path + image_file).convert('RGB')
                except Exception:
                    print(f'cannot identify image file {image_path + image_file}.')
return self.__getitem__(0)
try:
image, image_1 = self.image_processor(image)
                except Exception:
                    print(f'image {image_file} is broken or grayscale, falling back to sample 0 instead!')
return self.__getitem__(0)
conversations = self.multimodal_processor([data["conversations"]])
else:
conversations = [data]
# align with fastchat & llava here, put the conversation into a list for tokenization
data_dict = self.token_processor(conversations)
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
if isinstance(data, dict) and 'image' in data:
data_dict['image'] = [image]
data_dict['image_high'] = [image_1]
else:
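            # No image for this sample: provide zero tensors with the expected shapes.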
crop_size = self.multimodal_cfg['image_processor'].crop_size
data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
return data_dict
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer
from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.vision_encoder.sam import build_sam_vit_b
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def eval_model(args):
# Model
disable_torch_init()
model_name = os.path.expanduser(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = varyOPTForCausalLM.from_pretrained(model_name)
model.to(device='cuda', dtype=torch.bfloat16)
image_processor_high = test_transform
image_token_len = 256
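    # The prompt consists only of the image placeholder block (<img> + 256 * <imgpad> + </img>);
    # this evaluation script passes no text instruction to the OPT-based model.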
prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
inputs = tokenizer([prompt])
image = load_image(args.image_file)
image_1 = image.copy()
image_tensor_1 = image_processor_high(image_1).to(torch.bfloat16)
input_ids = torch.as_tensor(inputs.input_ids).cuda()
stop_str = '</s>'
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = model.generate(
input_ids,
images=[(image_tensor_1.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).cuda())],
do_sample=True,
            num_beams=1,
streamer=streamer,
max_new_tokens=4096,
stopping_criteria=[stopping_criteria]
)
# input_token_len = input_ids.shape[1]
# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
# if outputs.endswith(stop_str):
# outputs = outputs[:-len(stop_str)]
# outputs = outputs.strip()
# print(outputs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--image-file", type=str, required=True)
# parser.add_argument("--query", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
args = parser.parse_args()
eval_model(args)
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import requests
from io import BytesIO
from vary.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def eval_model(args):
# Model
disable_torch_init()
model_name = os.path.expanduser(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = varyQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', trust_remote_code=True)
model.to(device='cuda', dtype=torch.bfloat16)
image_processor = CLIPImageProcessor.from_pretrained("/data/hypertext/ucaswei/cache/vit-large-patch14/vit-large-patch14/", torch_dtype=torch.float16)
image_processor_high = BlipImageEvalProcessor(image_size=1024)
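    # Two preprocessing paths: CLIP preprocessing for the low-resolution input and a
    # 1024x1024 BLIP-style eval transform for the high-resolution input.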
use_im_start_end = True
image_token_len = 256
qs = 'Provide the ocr results of this image.'
if use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
conv_mode = "mpt"
args.conv_mode = conv_mode
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
image = load_image(args.image_file)
image_1 = image.copy()
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
image_tensor_1 = image_processor_high(image_1)
input_ids = torch.as_tensor(inputs.input_ids).cuda()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = model.generate(
input_ids,
images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
do_sample=True,
            num_beams=1,
# temperature=0.2,
streamer=streamer,
max_new_tokens=2048,
stopping_criteria=[stopping_criteria]
)
# print(output_ids)
# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
# # conv.messages[-1][-1] = outputs
# if outputs.endswith(stop_str):
# outputs = outputs[:-len(stop_str)]
# outputs = outputs.strip()
# print(outputs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--image-file", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
args = parser.parse_args()
eval_model(args)
from .vary_opt import varyOPTModel, varyOPTForCausalLM
# from .vary_qwen_vary import varyQwenModel, varyQwenForCausalLM, varyConfig
from .vary_toy_qwen1_8 import varyQwenModel, varyQwenForCausalLM, varyConfig
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from transformers import PretrainedConfig
class QWenConfig(PretrainedConfig):
model_type = "qwen"
keys_to_ignore_at_inference = ["past_key_values"]
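    # The defaults below appear to mirror the released Qwen-7B configuration; an actual
    # checkpoint overrides them through its own config.json when loaded with from_pretrained.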
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
emb_dropout_prob=0.0,
attn_dropout_prob=0.0,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
max_position_embeddings=8192,
scale_attn_weights=True,
use_cache=True,
bf16=False,
fp16=False,
fp32=False,
kv_channels=128,
rotary_pct=1.0,
rotary_emb_base=10000,
use_dynamic_ntk=True,
use_logn_attn=True,
use_flash_attn="auto",
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.emb_dropout_prob = emb_dropout_prob
self.attn_dropout_prob = attn_dropout_prob
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.bf16 = bf16
self.fp16 = fp16
self.fp32 = fp32
self.kv_channels = kv_channels
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.no_bias = no_bias
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs
)
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
from vary.train.train import train
if __name__ == "__main__":
train()
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
# from vary.train.train import train
from vary.train.train_lora import train
if __name__ == "__main__":
train()