Commit 34164470 authored by chenzk

v1.0
# Copyright (c) Facebook, Inc. and its affiliates.
import re
from tqdm import tqdm
class EvalAIAnswerProcessor:
"""
Processes an answer similar to Eval AI
copied from
https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
"""
CONTRACTIONS = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
}
NUMBER_MAP = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
ARTICLES = ["a", "an", "the"]
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
PUNCTUATIONS = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def __init__(self, *args, **kwargs):
pass
def word_tokenize(self, word):
word = word.lower()
word = word.replace(",", "").replace("?", "").replace("'s", " 's")
return word.strip()
def process_punctuation(self, in_text):
out_text = in_text
for p in self.PUNCTUATIONS:
if (p + " " in in_text or " " + p in in_text) or (
re.search(self.COMMA_STRIP, in_text) is not None
):
out_text = out_text.replace(p, "")
else:
out_text = out_text.replace(p, " ")
out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
return out_text
def process_digit_article(self, in_text):
out_text = []
temp_text = in_text.lower().split()
for word in temp_text:
word = self.NUMBER_MAP.setdefault(word, word)
if word not in self.ARTICLES:
out_text.append(word)
else:
pass
for word_id, word in enumerate(out_text):
if word in self.CONTRACTIONS:
out_text[word_id] = self.CONTRACTIONS[word]
out_text = " ".join(out_text)
return out_text
def __call__(self, item):
item = self.word_tokenize(item)
item = item.replace("\n", " ").replace("\t", " ").strip()
item = self.process_punctuation(item)
item = self.process_digit_article(item)
return item
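# Illustrative examples (not from the original source) of the full normalization
# pipeline -- lower-casing, punctuation stripping, number-word mapping, article
# removal, and contraction expansion:
#   EvalAIAnswerProcessor()("Isnt it two?")  ->  "isn't it 2"
#   EvalAIAnswerProcessor()("A red car.")    ->  "red car"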
class TextVQAAccuracyEvaluator:
def __init__(self):
self.answer_processor = EvalAIAnswerProcessor()
def _compute_answer_scores(self, raw_answers):
"""
compute the accuracy (soft score) of human answers
"""
answers = [self.answer_processor(a) for a in raw_answers]
assert len(answers) == 10
gt_answers = list(enumerate(answers))
unique_answers = set(answers)
unique_answer_scores = {}
for unique_answer in unique_answers:
accs = []
for gt_answer in gt_answers:
other_answers = [item for item in gt_answers if item != gt_answer]
matching_answers = [
item for item in other_answers if item[1] == unique_answer
]
acc = min(1, float(len(matching_answers)) / 3)
accs.append(acc)
unique_answer_scores[unique_answer] = sum(accs) / len(accs)
return unique_answer_scores
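# Scoring note: this follows the standard VQA soft-accuracy metric. For each of
# the 10 annotators, the candidate answer is scored as min(1, m / 3), where m is
# the number of the *other* 9 annotators who gave the same answer, and the 10
# leave-one-out scores are averaged.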
def eval_pred_list(self, pred_list):
pred_scores = []
for entry in tqdm(pred_list):
pred_answer = self.answer_processor(entry["pred_answer"])
unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
score = unique_answer_scores.get(pred_answer, 0.0)
pred_scores.append(score)
accuracy = sum(pred_scores) / len(pred_scores)
return accuracy
class STVQAAccuracyEvaluator:
def __init__(self):
self.answer_processor = EvalAIAnswerProcessor()
def eval_pred_list(self, pred_list):
pred_scores = []
for entry in pred_list:
pred_answer = self.answer_processor(entry["pred_answer"])
gts = [self.answer_processor(a) for a in entry["gt_answers"]]
score = 1.0 if pred_answer in gts else 0.0
pred_scores.append(score)
accuracy = sum(pred_scores) / len(pred_scores)
return accuracy
class STVQAANLSEvaluator:
def __init__(self):
import editdistance # install with `pip install editdistance`
self.get_edit_distance = editdistance.eval
def get_anls(self, s1, s2):
s1 = s1.lower().strip()
s2 = s2.lower().strip()
iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
anls = iou if iou >= 0.5 else 0.0
return anls
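# Illustrative examples (not from the original source): ANLS is 1 minus the
# normalized edit distance, thresholded at 0.5.
#   get_anls("hello", "helo")  -> 1 - 1/5 = 0.8
#   get_anls("hello", "world") -> 1 - 4/5 = 0.2, below the 0.5 threshold, so the score is 0.0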
def eval_pred_list(self, pred_list):
pred_scores = []
for entry in pred_list:
anls = max(
self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
)
pred_scores.append(anls)
accuracy = sum(pred_scores) / len(pred_scores)
return accuracy
class TextCapsBleu4Evaluator:
def __init__(self):
# The following script requires Java 1.8.0 and pycocotools installed.
# The pycocoevalcap can be installed with pip as
# pip install git+https://github.com/ronghanghu/coco-caption.git@python23
# Original pycocoevalcap code is at https://github.com/tylin/coco-caption
# but has no python3 support yet.
try:
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
except ModuleNotFoundError:
print(
"Please install pycocoevalcap module using "
"pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa
)
raise
self.tokenizer = PTBTokenizer()
self.scorer = Bleu(4)
def eval_pred_list(self, pred_list):
# Create reference and hypotheses captions.
gts = {}
res = {}
for idx, entry in enumerate(pred_list):
gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
res[idx] = [{"caption": entry["pred_answer"]}]
gts = self.tokenizer.tokenize(gts)
res = self.tokenizer.tokenize(res)
score, _ = self.scorer.compute_score(gts, res)
bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
return bleu4
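# All evaluators above consume a pred_list of dicts with the same schema
# (illustrative sketch, keys taken from the code above):
#   [{"pred_answer": "2", "gt_answers": ["two", "2", ...]}, ...]
# where gt_answers holds the human annotations (10 per question for TextVQA) or
# the reference captions for TextCaps.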
import os
import json
import math
import torch
import argparse
import shortuuid
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from mobilevlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from mobilevlm.conversation import conv_templates, SeparatorStyle
from mobilevlm.model.mobilevlm import load_pretrained_model
from mobilevlm.utils import disable_torch_init, tokenizer_image_token, process_images, get_model_name_from_path
def split_list(lst, n):
"""Split a list into n (roughly) equal-sized chunks"""
chunk_size = math.ceil(len(lst) / n)  # ceiling division so every element is assigned to a chunk
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
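# Illustrative example (not from the original source):
#   split_list([1, 2, 3, 4, 5], 2) -> [[1, 2, 3], [4, 5]]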
def get_chunk(lst, n, k):
chunks = split_list(lst, n)
return chunks[k]
# Custom dataset class
class CustomDataset(Dataset):
def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
self.questions = questions
self.image_folder = image_folder
self.tokenizer = tokenizer
self.image_processor = image_processor
self.model_config = model_config
def __getitem__(self, index):
line = self.questions[index]
image_file = line["image"]
qs = line["text"]
if self.model_config.mm_use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
image_tensor = process_images([image], self.image_processor, self.model_config)[0]
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
return input_ids, image_tensor
def __len__(self):
return len(self.questions)
# DataLoader
def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
assert batch_size == 1, "batch_size must be 1"
dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
return data_loader
def eval_model(args):
# Model
disable_torch_init()
model_path = os.path.expanduser(args.model_path)
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path)
questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
answers_file = os.path.expanduser(args.answers_file)
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
ans_file = open(answers_file, "w")
if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
args.conv_mode = args.conv_mode + '_mmtag'
print(f'It seems that this is a plain model, but it is not using an mmtag prompt; auto-switching to {args.conv_mode}.')
data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
idx = line["question_id"]
cur_prompt = line["text"]
stop_str = conv_templates[args.conv_mode].sep if conv_templates[args.conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[args.conv_mode].sep2
input_ids = input_ids.to(device='cuda', non_blocking=True)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
top_p=args.top_p,
num_beams=args.num_beams,
max_new_tokens=128,
use_cache=True)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
ans_id = shortuuid.uuid()
ans_file.write(json.dumps({"question_id": idx,
"prompt": cur_prompt,
"text": outputs,
"answer_id": ans_id,
"model_id": model_name,
"metadata": {}}) + "\n")
ans_file.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-folder", type=str, default="")
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
parser.add_argument("--conv-mode", type=str, default="v1")
parser.add_argument("--num-chunks", type=int, default=1)
parser.add_argument("--chunk-idx", type=int, default=0)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--num_beams", type=int, default=1)
args = parser.parse_args()
# benchmark_name = args.question_file.split('/')[4]
eval_model(args)
import os
import json
import math
import torch
import argparse
import shortuuid
import pandas as pd
from PIL import Image
from tqdm import tqdm
from mobilevlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from mobilevlm.conversation import conv_templates, SeparatorStyle
from mobilevlm.model.mobilevlm import load_pretrained_model
from mobilevlm.utils import disable_torch_init, tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path
all_options = ['A', 'B', 'C', 'D']
def split_list(lst, n):
"""Split a list into n (roughly) equal-sized chunks"""
chunk_size = math.ceil(len(lst) / n)  # ceiling division so every element is assigned to a chunk
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
def get_chunk(lst, n, k):
chunks = split_list(lst, n)
return chunks[k]
def is_none(value):
if value is None:
return True
if type(value) is float and math.isnan(value):
return True
if type(value) is str and value.lower() == 'nan':
return True
if type(value) is str and value.lower() == 'none':
return True
return False
def get_options(row, options):
parsed_options = []
for option in options:
option_value = row[option]
if is_none(option_value):
break
parsed_options.append(option_value)
return parsed_options
def eval_model(args):
# Model
disable_torch_init()
model_path = os.path.expanduser(args.model_path)
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path)
questions = pd.read_table(os.path.expanduser(args.question_file))
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
answers_file = os.path.expanduser(args.answers_file)
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
ans_file = open(answers_file, "w")
if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
args.conv_mode = args.conv_mode + '_mmtag'
print(f'It seems that this is a plain model, but it is not using an mmtag prompt; auto-switching to {args.conv_mode}.')
for index, row in tqdm(questions.iterrows(), total=len(questions)):
options = get_options(row, all_options)
cur_option_char = all_options[:len(options)]
if args.all_rounds:
num_rounds = len(options)
else:
num_rounds = 1
for round_idx in range(num_rounds):
idx = row['index']
question = row['question']
hint = row['hint']
image = load_image_from_base64(row['image'])
if not is_none(hint):
question = hint + '\n' + question
for option_char, option in zip(all_options[:len(options)], options):
question = question + '\n' + option_char + '. ' + option
qs = cur_prompt = question
if model.config.mm_use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
if args.single_pred_prompt:
if args.lang == 'cn':
qs = qs + '\n' + "请直接回答选项字母。"
else:
qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
image_tensor = process_images([image], image_processor, model.config)[0]
# image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor.unsqueeze(0).half().cuda(),
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
top_p=args.top_p,
num_beams=args.num_beams,
# no_repeat_ngram_size=3,
max_new_tokens=1024,
use_cache=True)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
ans_id = shortuuid.uuid()
ans_file.write(json.dumps({"question_id": idx,
"round_id": round_idx,
"prompt": cur_prompt,
"text": outputs,
"options": options,
"option_char": cur_option_char,
"answer_id": ans_id,
"model_id": model_name,
"metadata": {}}) + "\n")
ans_file.flush()
# rotate options
options = options[1:] + options[:1]
cur_option_char = cur_option_char[1:] + cur_option_char[:1]
ans_file.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-folder", type=str, default="")
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
parser.add_argument("--conv-mode", type=str, default="v1")
parser.add_argument("--num-chunks", type=int, default=1)
parser.add_argument("--chunk-idx", type=int, default=0)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--num_beams", type=int, default=1)
parser.add_argument("--all-rounds", action="store_true")
parser.add_argument("--single-pred-prompt", action="store_true")
parser.add_argument("--lang", type=str, default="en")
args = parser.parse_args()
eval_model(args)
import os
import json
import math
import torch
import argparse
import shortuuid
from PIL import Image
from tqdm import tqdm
from mobilevlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from mobilevlm.conversation import conv_templates, SeparatorStyle
from mobilevlm.model.mobilevlm import load_pretrained_model
from mobilevlm.utils import disable_torch_init, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
def split_list(lst, n):
"""Split a list into n (roughly) equal-sized chunks"""
chunk_size = math.ceil(len(lst) / n)  # ceiling division so every element is assigned to a chunk
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
def get_chunk(lst, n, k):
chunks = split_list(lst, n)
return chunks[k]
def eval_model(args):
# Model
disable_torch_init()
model_path = os.path.expanduser(args.model_path)
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path)
questions = json.load(open(os.path.expanduser(args.question_file), "r"))
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
answers_file = os.path.expanduser(args.answers_file)
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
ans_file = open(answers_file, "w")
for i, line in enumerate(tqdm(questions)):
idx = line["id"]
question = line['conversations'][0]
qs = question['value'].replace('<image>', '').strip()
cur_prompt = qs
if 'image' in line:
image_file = line["image"]
image = Image.open(os.path.join(args.image_folder, image_file))
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
images = image_tensor.unsqueeze(0).half().cuda()
if getattr(model.config, 'mm_use_im_start_end', False):
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
cur_prompt = '<image>' + '\n' + cur_prompt
else:
images = None
if args.single_pred_prompt:
qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] if conv.version == "v0" else None
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=images,
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
max_new_tokens=1024,
use_cache=True,
stopping_criteria=stopping_criteria,
)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
# prompt for answer
if args.answer_prompter:
outputs_reasoning = outputs
input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=images,
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
max_new_tokens=64,
use_cache=True,
stopping_criteria=stopping_criteria)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
outputs = outputs_reasoning + '\n The answer is ' + outputs
ans_id = shortuuid.uuid()
ans_file.write(json.dumps({"question_id": idx,
"prompt": cur_prompt,
"text": outputs,
"answer_id": ans_id,
"model_id": model_name,
"metadata": {}}) + "\n")
ans_file.flush()
ans_file.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-folder", type=str, default="")
parser.add_argument("--question-file", type=str, default="tables/question.json")
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
parser.add_argument("--conv-mode", type=str, default="v1")
parser.add_argument("--num-chunks", type=int, default=1)
parser.add_argument("--chunk-idx", type=int, default=0)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--answer-prompter", action="store_true")
parser.add_argument("--single-pred-prompt", action="store_true")
args = parser.parse_args()
eval_model(args)
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from mobilevlm.model.mobilevlm import MobileVLMMetaModel, MobileVLMMetaForCausalLM
class MobileVLMConfig(LlamaConfig):
model_type = "mobilevlm"
class MobileLlamaModel(MobileVLMMetaModel, LlamaModel):
config_class = MobileVLMConfig
def __init__(self, config: LlamaConfig):
super(MobileLlamaModel, self).__init__(config)
class MobileLlamaForCausalLM(LlamaForCausalLM, MobileVLMMetaForCausalLM):
config_class = MobileVLMConfig
def __init__(self, config):
super(LlamaForCausalLM, self).__init__(config)
self.model = MobileLlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init() # Initialize weights and apply final processing
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_ids, attention_mask, past_key_values, inputs_embeds, labels = \
self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model/pipeline parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"images": kwargs.get("images", None),
}
)
return model_inputs
AutoConfig.register("mobilevlm", MobileVLMConfig)
AutoModelForCausalLM.register(MobileVLMConfig, MobileLlamaForCausalLM)
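# Registering the config and model class lets the generic Auto* loaders resolve
# checkpoints whose config declares model_type == "mobilevlm", e.g. (illustrative):
#   model = AutoModelForCausalLM.from_pretrained("path/to/mobilevlm-checkpoint")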
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from transformers import AutoTokenizer, BitsAndBytesConfig
from mobilevlm.model.vision_encoder import build_vision_tower
from mobilevlm.model.vision_projector import build_vision_projector
from mobilevlm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, \
DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
class MobileVLMMetaModel:
def __init__(self, config):
super(MobileVLMMetaModel, self).__init__(config)
if hasattr(config, "mm_vision_tower"):
self.vision_tower = build_vision_tower(config, delay_load=True)
self.mm_projector = build_vision_projector(config)
def get_vision_tower(self):
vision_tower = getattr(self, 'vision_tower', None)
if type(vision_tower) is list:
vision_tower = vision_tower[0]
return vision_tower
def initialize_vision_modules(self, model_args, fsdp=None):
mm_vision_select_layer = model_args.mm_vision_select_layer
mm_vision_select_feature = model_args.mm_vision_select_feature
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
self.config.mm_vision_tower = model_args.vision_tower
vision_tower = model_args.vision_tower
if self.get_vision_tower() is None:
vision_tower = build_vision_tower(model_args)
if fsdp is not None and len(fsdp) > 0:
self.vision_tower = [vision_tower]
else:
self.vision_tower = vision_tower
elif self.get_vision_tower().vision_tower_name != vision_tower:
vision_tower = build_vision_tower(model_args)
if fsdp is not None and len(fsdp) > 0:
self.vision_tower = [vision_tower]
else:
self.vision_tower = vision_tower
else:
if fsdp is not None and len(fsdp) > 0:
vision_tower = self.vision_tower[0]
vision_tower.load_model()
else:
vision_tower = self.vision_tower
vision_tower.load_model()
self.config.use_mm_proj = True
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
self.config.mm_vision_select_layer = mm_vision_select_layer
self.config.mm_vision_select_feature = mm_vision_select_feature
# Build VisionTower
vision_tower = build_vision_tower(model_args)
if fsdp is not None and len(fsdp) > 0:
self.vision_tower = [vision_tower]
else:
self.vision_tower = vision_tower
self.config.mm_hidden_size = vision_tower.hidden_size
# Build Vision-Projector
if getattr(self, 'mm_projector', None) is None:
self.mm_projector = build_vision_projector(self.config)
if pretrain_mm_mlp_adapter is not None:
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
def get_w(weights, keyword):
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
class MobileVLMMetaForCausalLM(ABC):
@abstractmethod
def get_model(self):
pass
def get_vision_tower(self):
return self.get_model().get_vision_tower()
def encode_images(self, images):
image_features = self.get_model().get_vision_tower()(images)
image_features = self.get_model().mm_projector(image_features)
return image_features
def prepare_inputs_labels_for_multimodal(
self, input_ids, attention_mask, past_key_values, labels, images
):
vision_tower = self.get_vision_tower()
if vision_tower is None or images is None or input_ids.shape[1] == 1:
if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
return input_ids, attention_mask, past_key_values, None, labels
if type(images) is list or images.ndim == 5:
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1) for x in image_features]
else:
image_features = self.encode_images(images)
new_input_embeds = []
new_labels = [] if labels is not None else None
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
# FIXME: this is a hacky fix, for deepspeed zero3 to work
half_len = cur_input_ids.shape[0] // 2
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
new_input_embeds.append(cur_input_embeds)
if labels is not None:
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
cur_new_input_embeds = []
if labels is not None:
cur_labels = labels[batch_idx]
cur_new_labels = []
assert cur_labels.shape == cur_input_ids.shape
while image_token_indices.numel() > 0:
cur_image_features = image_features[cur_image_idx]
image_token_start = image_token_indices[0]
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
cur_new_input_embeds.append(cur_image_features)
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
if labels is not None:
cur_new_labels.append(cur_labels[:image_token_start])
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
cur_labels = cur_labels[image_token_start+2:]
else:
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
cur_new_input_embeds.append(cur_image_features)
if labels is not None:
cur_new_labels.append(cur_labels[:image_token_start])
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
cur_labels = cur_labels[image_token_start+1:]
cur_image_idx += 1
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
cur_input_ids = cur_input_ids[image_token_start+2:]
else:
cur_input_ids = cur_input_ids[image_token_start+1:]
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
if cur_input_ids.numel() > 0:
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
else:
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
if labels is not None:
cur_new_labels.append(cur_labels)
cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
new_input_embeds.append(cur_new_input_embeds)
if labels is not None:
cur_new_labels = torch.cat(cur_new_labels, dim=0)
new_labels.append(cur_new_labels)
if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
max_len = max(x.shape[0] for x in new_input_embeds)
new_input_embeds_align = []
for cur_new_embed in new_input_embeds:
cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
new_input_embeds_align.append(cur_new_embed)
new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
if labels is not None:
new_labels_align = []
_new_labels = new_labels
for cur_new_label in new_labels:
cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
new_labels_align.append(cur_new_label)
new_labels = torch.stack(new_labels_align, dim=0)
if attention_mask is not None:
new_attention_mask = []
for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
new_attention_mask.append(cur_new_attention_mask)
attention_mask = torch.stack(new_attention_mask, dim=0)
assert attention_mask.shape == new_labels.shape
else:
new_input_embeds = torch.stack(new_input_embeds, dim=0)
if labels is not None:
new_labels = torch.stack(new_labels, dim=0)
if attention_mask is not None:
new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
assert attention_mask.shape == new_input_embeds.shape[:2]
return None, attention_mask, past_key_values, new_input_embeds, new_labels
def initialize_vision_tokenizer(self, model_args, tokenizer):
if model_args.mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if model_args.mm_use_im_start_end:
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
assert num_new_tokens == 2
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
elif model_args.mm_use_im_patch_token:
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = False
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
def load_pretrained_model(model_path, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
from mobilevlm.model.mobilellama import MobileLlamaForCausalLM
kwargs = {"device_map": device_map}
if load_8bit:
kwargs['load_in_8bit'] = True
elif load_4bit:
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
)
else:
kwargs['torch_dtype'] = torch.float16
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = MobileLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
vision_tower = model.get_vision_tower()
if 'v2' in getattr(model.config, "mm_projector_type", "ldpnet"):
vision_tower.load_image_processor()
elif not vision_tower.is_loaded:
vision_tower.load_model()
vision_tower.to(device=device, dtype=torch.float16)
image_processor = vision_tower.image_processor
if hasattr(model.config, "max_sequence_length"):
context_len = model.config.max_sequence_length
else:
context_len = 2048
return tokenizer, model, image_processor, context_len
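# Illustrative usage (the checkpoint path below is a placeholder):
#   tokenizer, model, image_processor, context_len = load_pretrained_model(
#       "path/to/MobileVLM-checkpoint", load_8bit=False, load_4bit=False)
#   model.eval()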
import os
import torch
import random
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
class CLIPVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
if not delay_load:
self.load_model()
else:
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) # dummy-load
def load_model(self):
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def load_image_processor(self):
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
self.is_loaded = True
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == 'patch':
image_features = image_features[:, 1:]
elif self.select_feature == 'cls_patch':
image_features = image_features
else:
raise ValueError(f'Unexpected select feature: {self.select_feature}')
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
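# With select_feature == 'patch' the CLS token is dropped, so the returned tensor
# has shape (batch, num_patches, hidden_size); with 'cls_patch' the CLS token is
# kept and the token dimension is num_patches + 1.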
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
if self.select_feature.startswith('mtcv'):
num_select = int(self.select_feature.split('-')[-1])
return self.config.hidden_size * num_select
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
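# Example (illustrative): a CLIP ViT-L/14 tower at 336x336 resolution gives
# num_patches = (336 // 14) ** 2 = 576.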
def build_vision_tower(model_cfg, **kwargs):
vision_tower = getattr(model_cfg, 'mm_vision_tower', getattr(model_cfg, 'vision_tower', None))
is_absolute_path_exists = os.path.exists(vision_tower)
if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
vision_tower_type = getattr(model_cfg, 'vision_tower_type', None)
if vision_tower_type == "clip":
return CLIPVisionTower(vision_tower, args=model_cfg, **kwargs)
raise ValueError(f'Unknown vision tower: {vision_tower}')
import re
import math
import torch
from torch import nn
from functools import partial
from timm.layers.norm_act import LayerNormAct2d
from torchvision.ops.misc import SqueezeExcitation as SElayer
from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig
class LDPBlock(nn.Module):
# Lightweight Downsample Projector Block
def __init__(self, config=None):
super().__init__()
inc, ouc = config.mm_hidden_size, config.hidden_size
layer_norm = partial(LayerNormAct2d, act_layer=None)
se_layer = partial(SElayer, scale_activation=nn.Hardsigmoid)
self.mlp = nn.Sequential(
nn.Identity(), nn.Linear(inc, ouc), nn.GELU(), nn.Linear(ouc, ouc)
)
self.mb_block = nn.Sequential(
nn.Identity(),
InvertedResidual(InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 1, 1, 1), layer_norm, se_layer),
InvertedResidual(InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 2, 1, 1), layer_norm, se_layer)
)
def forward(self, x):
b, num_tokens, c = x.shape
h = int(math.sqrt(num_tokens))
x = self.mlp(x)
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
x = self.mb_block(x)
x = x.flatten(2).permute(0, 2, 1)
return x
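# Note (not part of the original file): the second InvertedResidual uses stride 2,
# so the spatial grid is halved per side and the number of visual tokens drops by
# 4x, e.g. a 24x24 grid (576 tokens) is reduced to 12x12 (144 tokens).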
class FeatureIRLayer(nn.Module):
def __init__(self, in_dim: int, out_dim: int) -> None:
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
class TokenDownLayer(nn.Module):
def __init__(self, shape) -> None:
super().__init__()
self.dwn = nn.Sequential(
nn.AdaptiveAvgPool2d(shape)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, num_tokens, c = x.shape
h = int(math.sqrt(num_tokens))
assert h * h == num_tokens
x = x.permute(0, 2, 1).reshape(b, -1, h, h)
x = self.dwn(x)
x = x.flatten(2).transpose(1, 2)
return x
class PosInjectLayer(nn.Module):
# https://github.com/Meituan-AutoML/Twins/blob/main/gvt.py
def __init__(self, in_dim: int, out_dim: int, stride: int = 1) -> None:
super().__init__()
self.peg = nn.Sequential(
nn.Conv2d(in_dim, out_dim, 3, stride, 1, bias=True, groups=out_dim)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, num_tokens, c = x.shape
h = int(math.sqrt(num_tokens))
assert h * h == num_tokens
cnn_feat = x.transpose(1, 2).view(b, c, h, h)
x = self.peg(cnn_feat) + cnn_feat
x = x.flatten(2).transpose(1, 2)
return x
class LDPNetProjector(nn.Module):
def __init__(self, config=None):
super().__init__()
self.model = LDPBlock(config)
def forward(self, x):
return self.model(x)
class LDPNetV2Projector(nn.Module):
def __init__(self, config=None):
super().__init__()
inc, ouc = config.mm_hidden_size, config.hidden_size
self.mlp = FeatureIRLayer(inc, ouc)
self.dwn = TokenDownLayer((12, 12))
self.peg = PosInjectLayer(ouc, ouc, stride=1)
def forward(self, x):
x = self.mlp(x)
x = self.dwn(x)
x = self.peg(x)
return x
def build_vision_projector(config, delay_load=False, **kwargs):
projector_type = getattr(config, 'mm_projector_type', 'linear')
if projector_type == 'linear':
return nn.Linear(config.mm_hidden_size, config.hidden_size)
elif projector_type.startswith('mlp'):
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
elif projector_type.startswith('ldpnetv2'):
return LDPNetV2Projector(config)
elif projector_type.startswith('ldpnet'):
return LDPNetProjector(config)
raise ValueError(f'Unknown projector type: {projector_type}')
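# Supported mm_projector_type values, as dispatched above: 'linear',
# 'mlp{N}x_gelu' (an N-layer MLP with GELU), 'ldpnetv2' (LDPNetV2Projector),
# and 'ldpnet' (LDPNetProjector).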
from typing import Optional, Tuple
import warnings
import torch
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
.transpose(1, 2)
) # shape: (b, num_heads, s, head_dim)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
if past_key_value is not None:
# reuse k, v
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Transform the data into the format required by flash attention
qkv = torch.stack([query_states, key_states, value_states], dim=2)
qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
cu_q_lens = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
)
max_s = q_len
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = output.view(bsz, q_len, -1)
else:
qkv = qkv.reshape(bsz, q_len, -1)
qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
output_unpad = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
output = pad_input(output_unpad, indices, bsz, q_len)
return self.o_proj(output), None, past_key_value
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
cuda_major, cuda_minor = torch.cuda.get_device_capability()
if cuda_major < 8:
warnings.warn(
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
)
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
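# Illustrative usage (not part of the original file): call
# replace_llama_attn_with_flash_attn() before the Llama-based model is constructed,
# so that the patched forward and attention-mask hook are in place when the model
# classes are instantiated.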