Unverified Commit 3a232db4 authored by Haodong Duan, committed by GitHub

[Deprecate] Remove multi-modal related stuff (#1072)



* Remove MultiModal

* update index.rst

* update README

* remove mmbench codes

* update news

---------
Co-authored-by: Leymore <zfz-960727@163.com>
parent f1ee11de
import json
import os
import os.path as osp
import sys
from pathlib import Path
import clip
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from timm.models.vision_transformer import Block
from opencompass.registry import MM_MODELS
def load_package():
"""Load required packages from llama_adapter_v2_multimodal7b."""
current_file_path = os.path.abspath(__file__)
current_folder_path = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_folder_path, 'LLaMA-Adapter')) # noqa
from llama_adapter_v2_multimodal7b.llama.llama import (ModelArgs,
Transformer)
from llama_adapter_v2_multimodal7b.llama.tokenizer import Tokenizer
from llama_adapter_v2_multimodal7b.llama.utils import sample_top_p
sys.path.pop(-1)
return ModelArgs, Transformer, Tokenizer, sample_top_p
ModelArgs, Transformer, Tokenizer, sample_top_p = load_package()
class LLaMA_adapter(nn.Module):
def __init__(self,
llama_ckpt_dir,
llama_tokenizer,
max_seq_len=512,
max_batch_size=1,
clip_model='ViT-L/14',
v_embed_dim=768,
v_depth=8,
v_num_heads=16,
v_mlp_ratio=4.0,
query_len=10,
query_layer=31,
w_bias=False,
w_lora=False,
lora_rank=16,
prompt_constructor=None,
post_processor=None):
super().__init__()
self.device = get_device()
# load llama configs
with open(os.path.join(llama_ckpt_dir, 'params.json'), 'r') as f:
params = json.loads(f.read())
model_args = ModelArgs(max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params)
# 1. clip and clip projector
self.clip, self.clip_transform = clip.load(clip_model)
clip_dim = self.clip.visual.proj.shape[1]
self.clip_proj = nn.Linear(clip_dim, v_embed_dim)
self.clip_proj_norm = nn.LayerNorm(v_embed_dim)
self.query_len = query_len
self.query_layer = query_layer
# 2. visual query, blocks and projector
self.visual_query = nn.Embedding(query_len, v_embed_dim)
self.visual_blocks = nn.ModuleList([
Block(v_embed_dim, v_num_heads, v_mlp_ratio, qkv_bias=True)
for _ in range(v_depth)
])
self.visual_proj = nn.Linear(v_embed_dim, model_args.dim)
self.visual_proj_norm = nn.LayerNorm(model_args.dim)
# 3. adapter query
self.adapter_query = nn.Embedding(query_len * query_layer,
model_args.dim)
# 4. tokenizer
self.tokenizer = Tokenizer(model_path=llama_tokenizer)
# 5. llama
model_args.vocab_size = self.tokenizer.n_words
model_args.w_bias = w_bias
model_args.w_lora = w_lora
model_args.lora_rank = lora_rank
torch.set_default_tensor_type(torch.cuda.HalfTensor)
self.llama = Transformer(model_args)
torch.set_default_tensor_type(torch.FloatTensor)
ckpts = sorted(Path(llama_ckpt_dir).glob('*.pth'))
for ckpt in ckpts:
ckpt = torch.load(ckpt, map_location='cpu')
self.llama.load_state_dict(ckpt, strict=False)
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
if post_processor is not None:
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
def clip_encode_image(self, x):
# modified from CLIP
x = self.clip.visual.conv1(x) # shape = [*, width, grid, grid]
# shape = [*, width, grid ** 2]
x = x.reshape(x.shape[0], x.shape[1], -1)
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([
self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
],
dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.clip.visual.positional_embedding.to(x.dtype)
x = self.clip.visual.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.clip.visual.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
# preserve all spatial tokens
x = self.clip.visual.ln_post(x[:, :, :])
if self.clip.visual.proj is not None:
x = x @ self.clip.visual.proj
return x
def forward_visual(self, imgs):
clip_feats = self.clip_encode_image(imgs)
clip_feats = self.clip_proj_norm(self.clip_proj(clip_feats.float()))
visual_query = self.visual_query.weight.unsqueeze(0).repeat(
len(imgs), 1, 1)
visual_query = torch.cat([visual_query, clip_feats], dim=1)
for block in self.visual_blocks:
visual_query = block(visual_query)
visual_query = visual_query[:, :self.query_len, :]
visual_query = self.visual_proj(visual_query)
visual_query = self.visual_proj_norm(visual_query)
return visual_query
@torch.inference_mode()
def forward(self, visual_query, tokens, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.llama.tok_embeddings(tokens)
freqs_cis = self.llama.freqs_cis.to(h.device)
freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
mask = None
mask = torch.full((1, 1, seqlen, seqlen),
float('-inf'),
device=h.device)
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
for layer in self.llama.layers[:-1 * self.query_layer]:
h = layer(h, start_pos, freqs_cis, mask)
adapter = self.adapter_query.weight.reshape(self.query_layer,
self.query_len,
-1).unsqueeze(1)
adapter_index = 0
for layer in self.llama.layers[-1 * self.query_layer:]:
dynamic_adapter = adapter[adapter_index].repeat(_bsz, 1, 1)
dynamic_adapter = dynamic_adapter + visual_query
h = layer(h, start_pos, freqs_cis, mask, dynamic_adapter)
adapter_index = adapter_index + 1
h = self.llama.norm(h)
output = self.llama.output(h[:, -1, :])
return output.float()
def pack_inputs(self, batch):
images = [image.unsqueeze(0) for image in batch['inputs']]
data_samples = [data_sample for data_sample in batch['data_samples']]
images = torch.cat(images, dim=0).to(get_device())
inputs = {'image': images, 'data_samples': data_samples}
return inputs
@torch.inference_mode()
def generate(self, batch):
max_gen_len = 256
temperature = 0.1
top_p = 0.75
inputs = self.pack_inputs(batch)
inputs = self.prompt_constructor(inputs)
image = inputs['image']
prompts = inputs['prompt']
data_samples = inputs['data_samples']
data_sample = data_samples[0]
imgs = image
bsz = len(imgs)
params = self.llama.params
with torch.cuda.amp.autocast():
visual_query = self.forward_visual(imgs)
if isinstance(prompts[0], str):
prompts = [
self.tokenizer.encode(x, bos=True, eos=False) for x in prompts
]
min_prompt_size = min([len(t) for t in prompts])
max_prompt_size = max([len(t) for t in prompts])
total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
tokens = torch.full((bsz, total_len),
self.tokenizer.pad_id).cuda().long()
for k, t in enumerate(prompts):
if len(t) <= total_len:
tokens[k, :len(t)] = torch.tensor(t).cuda().long()
else:
tokens[k, :total_len] = torch.tensor(
t[:total_len]).cuda().long()
input_text_mask = tokens != self.tokenizer.pad_id
start_pos = min_prompt_size
prev_pos = 0
for cur_pos in range(start_pos, total_len):
with torch.cuda.amp.autocast():
logits = self.forward(visual_query,
tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(-1)
next_token = torch.where(input_text_mask[:, cur_pos],
tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
# trick: early stop if bsz==1
if bsz == 1 and next_token[0] == self.tokenizer.eos_id:
break
prev_pos = cur_pos
decoded = []
for i, t in enumerate(tokens.tolist()):
# cut to max gen len
t = t[len(prompts[i]):len(prompts[i]) + max_gen_len]
# cut to eos tok if any
try:
t = t[:t.index(self.tokenizer.eos_id)]
except ValueError:
pass
decoded.append(self.tokenizer.decode(t))
output_text = self.post_processor(decoded[0])
data_sample.pred_answer = output_text
return data_sample
@MM_MODELS.register_module('LLaMA-adapter-v2')
class LLaMA_adapter_v2(nn.Module):
def __init__(self,
llama_dir,
prompt_constructor: dict,
post_processor: dict,
model_path: str = 'llama_adapter_v2_multimodal7b',
name: str = 'LORA-BIAS-7B',
mode: str = 'generation',
device='cuda' if torch.cuda.is_available() else 'cpu',
download_root='ckpts'):
super().__init__()
assert name in ['LORA-BIAS-7B', 'BIAS-7B', 'CAPTION-7B']
# BIAS-7B or https://xxx/sha256_BIAS-7B.pth -> 7B
llama_type = name.split('.')[0].split('-')[-1]
llama_ckpt_dir = os.path.join(llama_dir, llama_type)
llama_tokenizer_path = os.path.join(llama_dir, 'tokenizer.model')
# load llama_adapter weights and model_cfg
print(f'Loading LLaMA-Adapter from {llama_dir}')
current_file_path = os.path.abspath(__file__)
current_folder_path = os.path.dirname(current_file_path)
model_path = osp.join(current_folder_path, 'LLaMA-Adapter', model_path)
ckpt_root = osp.join(model_path, download_root)
ckpt_map = {
'LORA-BIAS-7B':
'1bcbffc43484332672092e0024a8699a6eb5f558161aebf98a7c6b1db67224d1_LORA-BIAS-7B.pth', # noqa: E501
'BIAS-7B':
'7fa55208379faf2dd862565284101b0e4a2a72114d6490a95e432cf9d9b6c813_BIAS-7B.pth', # noqa: E501
'CAPTION-7B':
'5088aeb63a89746b90bcfd5cb819e1c7411b2771b267c6d131ce73e250a8abf0_CAPTION-7B.pth' # noqa: E501
}
ckpt = torch.load(osp.join(ckpt_root, ckpt_map[name]),
map_location='cpu')
model_cfg = ckpt.get('config', {})
self.model = LLaMA_adapter(
llama_ckpt_dir,
llama_tokenizer_path,
max_seq_len=512,
max_batch_size=1,
clip_model='ViT-L/14',
v_embed_dim=768,
v_depth=8,
v_num_heads=16,
v_mlp_ratio=4.0,
query_len=10,
query_layer=31,
w_bias=model_cfg.get('w_bias', False),
w_lora=model_cfg.get('w_lora', False),
lora_rank=model_cfg.get('lora_rank', 16),
prompt_constructor=prompt_constructor,
post_processor=post_processor,
)
self.model.load_state_dict(ckpt['model'], strict=False)
self.mode = mode
def forward(self, batch):
if self.mode == 'generation':
return self.model.generate(batch)
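# ---------------------------------------------------------------------------
# Editor's sketch (not part of the removed file): one way this wrapper could
# be assembled through the MM_MODELS registry. The llama_dir value is a
# placeholder, and the prompt constructor / post processor classes are the
# MMBench ones defined later in this commit; real configs may have used
# registered string names instead of the classes themselves.
#
#   import mmengine
#   from opencompass.registry import MM_MODELS
#
#   cfg = dict(type='LLaMA-adapter-v2',
#              llama_dir='/path/to/llama',  # placeholder
#              prompt_constructor=dict(type=LlamaAadapterMMBenchPromptConstructor),
#              post_processor=dict(type=LlamaAadapterMMBenchPostProcessor))
#   model = mmengine.registry.build_from_cfg(cfg, MM_MODELS)
#   data_sample = model(batch)  # batch = {'inputs': [...], 'data_samples': [...]}
# ---------------------------------------------------------------------------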
import torch
class LlamaAadapterMMBenchPostProcessor:
""""Post processor for Llama Aadapter V2 on MMBench."""
def __init__(self) -> None:
pass
def __call__(self, output_token: str) -> str:
if len(output_token) >= 2:
if output_token[1] == '.':
output_token = output_token[2:].strip()
return output_token
from typing import List
from mmpretrain.structures import DataSample
class LlamaAadapterMMBenchPromptConstructor:
"""Prompt constructor for Llama Adapter v2 on MMBench.
Args:
image_prompt (str): Image prompt. Defaults to `''`.
reply_prompt (str): Reply prompt. Defaults to `''`.
"""
def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
self.image_prompt = image_prompt
self.reply_prompt = reply_prompt
def __call__(self, inputs: dict) -> dict:
"""Construct prompt.
Args:
inputs (dict): Input data containing image and data_samples.
Returns:
dict: A dict containing prompt, images and data_samples.
"""
data_samples = inputs['data_samples']
prompt = self._process(data_samples)
inputs.update({'prompt': prompt})
return inputs
def _process(self, data_samples: List[DataSample]) -> List[str]:
"""Process data sample to prompt.
Args:
data_samples (List[DataSample]): A list of data_samples.
Returns:
List[str]: A list of prompts, one per data sample.
"""
question = [
data_sample.get('question') for data_sample in data_samples
]
options = [data_sample.get('options') for data_sample in data_samples]
if data_samples[0].get('context') is not None:
context = [
data_sample.get('context') for data_sample in data_samples
]
else:
context = [''] * len(data_samples)
prompts = []
for cur_context, cur_question, cur_options in zip(
context, question, options):
prompts.append(cur_context + ' ' + cur_question + ' ' +
cur_options) # noqa
return prompts
from .llava import LLaVA
from .post_processor import LLaVABasePostProcessor, LLaVAVSRPostProcessor
from .prompt_constructor import (LLaVABasePromptConstructor,
LLaVAMMBenchPromptConstructor,
LLaVAScienceQAPromptConstructor,
LLaVAVQAPromptConstructor)
__all__ = [
'LLaVA', 'LLaVABasePromptConstructor', 'LLaVAMMBenchPromptConstructor',
'LLaVABasePostProcessor', 'LLaVAVQAPromptConstructor',
'LLaVAScienceQAPromptConstructor', 'LLaVAVSRPostProcessor'
]
import importlib
import os
import sys
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import StoppingCriteria
from opencompass.registry import MM_MODELS
IMAGE_TOKEN_INDEX = -200
def load_package():
"""Load required packages from LLaVA."""
current_file_path = os.path.abspath(__file__)
current_folder_path = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_folder_path, 'LLaVA')) # noqa
return
class KeywordsStoppingCriteria(StoppingCriteria):
"""Keyword stopping criteria implemented for llava."""
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.tokenizer = tokenizer
self.start_len = None
self.input_ids = input_ids
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor,
**kwargs) -> bool:
if self.start_len is None:
self.start_len = self.input_ids.shape[1]
else:
outputs = self.tokenizer.batch_decode(output_ids[:,
self.start_len:],
skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
@MM_MODELS.register_module('llava')
class LLaVA(nn.Module):
"""Inference code of LLaVA. Need to clone LLaVA official repo first. Please
check out the README in config.
Args:
model_path (str): The path of llava checkpoint.
prompt_constructor (dict): The config of prompt constructor.
post_processor (dict): The config of post processor.
is_caption_task (bool): Whether the task is caption task.
Defaults to False.
"""
def __init__(
self,
model_path: str,
prompt_constructor: dict,
post_processor: dict,
is_caption_task: bool = False,
) -> None:
super().__init__()
self.dtype = torch.float16
self.is_caption_task = is_caption_task
# load LLaVA modules
load_package()
mm_utils = importlib.import_module('llava.mm_utils')
builder = importlib.import_module('llava.model.builder')
# load pretrained LLaVA
# Note: When encounters with device related errors,
# try setting `low_cpu_mem_usage` in `load_pretrained_model` as False
model_name = mm_utils.get_model_name_from_path(model_path)
tokenizer, model, _, _ = builder.load_pretrained_model(
model_path, None, model_name)
vision_tower = model.get_vision_tower()
vision_tower.to(device=get_device(), dtype=self.dtype)
model.to(device=get_device(), dtype=self.dtype)
# load prompt constructor and post processor
if 'v1' in model_path.lower():
conv_mode = 'llava_v1'
elif 'mpt' in model_path.lower():
conv_mode = 'mpt_multimodal'
else:
conv_mode = 'multimodal'
mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end',
False)
prompt_constructor.update({
'conv_mode': conv_mode,
'mm_use_im_start_end': mm_use_im_start_end
})
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
self.model = model
self.tokenizer = tokenizer
def generate(self, batch):
prompt, stop_str = self.prompt_constructor(batch)
keywords = [stop_str]
data_sample = batch['data_samples'][0]
image = batch['inputs'][0].unsqueeze(0)
if image is not None:
images = image.to(get_device())
else:
images = None
mm_utils = importlib.import_module('llava.mm_utils')
input_ids = mm_utils.tokenizer_image_token(
prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
return_tensors='pt').unsqueeze(0).to(get_device())
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer,
input_ids)
with torch.inference_mode():
output_ids = self.model.generate(
input_ids,
images=images.half(),
do_sample=True,
temperature=0.2,
max_new_tokens=1024,
stopping_criteria=[stopping_criteria],
)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids !=
output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(
f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids' # noqa
)
outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
skip_special_tokens=True)[0]
output_text = self.post_processor(outputs, stop_str)
if self.is_caption_task:
data_sample.pred_caption = output_text
else:
data_sample.pred_answer = output_text
return data_sample
def forward(self, batch):
return self.generate(batch)
class LLaVABasePostProcessor:
"""Base post processor for LLaVA on MMBench."""
def __init__(self) -> None:
pass
def __call__(self, outputs: str, stop_str: str) -> str:
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
output_text = outputs.strip()
return output_text
class LLaVAVSRPostProcessor(LLaVABasePostProcessor):
"""VSR post processor for LLaVA on MMBench."""
def __init__(self) -> None:
super().__init__()
def __call__(self, outputs: str, stop_str: str) -> str:
output_text = super().__call__(outputs, stop_str)
if 'yes' in output_text.lower():
return 'yes'
elif 'no' in output_text.lower():
return 'no'
else:
return 'unknown'
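# Editor's sketch (not part of the removed file): a quick illustration of the
# post processors above; '###' merely stands in for whatever stop string the
# conversation template supplies.
if __name__ == '__main__':
    base = LLaVABasePostProcessor()
    assert base('The answer is (A).###', '###') == 'The answer is (A).'
    vsr = LLaVAVSRPostProcessor()
    assert vsr('Yes, it is.###', '###') == 'yes'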
import importlib
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'
class LLaVABasePromptConstructor:
"""Base prompt constructor for LLaVA.
Args:
conv_mode (str): Conversation template version matching the LLaVA release.
mm_use_im_start_end (bool): Whether to wrap the image token with
    start and end tokens when building the prompt.
reply_prompt (str): Reply prompt added at the end. (Default: '')
"""
def __init__(self,
conv_mode: str,
mm_use_im_start_end: bool,
reply_prompt: str = '') -> None:
conversation = importlib.import_module('llava.conversation')
self.conv_templates = conversation.conv_templates
self.conv_mode = conv_mode
self.mm_use_im_start_end = mm_use_im_start_end
self.SeparatorStyle = conversation.SeparatorStyle
self.reply_prompt = reply_prompt
def __call__(self, inputs: dict) -> tuple:
"""Construct prompt.
Args:
inputs (dict): Input data containing images and data_samples.
Returns:
tuple: A tuple containing prompt, images and data_samples.
"""
data_samples = inputs['data_samples']
assert len(data_samples) == 1
prompt = self._build_prompt(data_samples[0])
if self.mm_use_im_start_end:
prompt = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN +
DEFAULT_IM_END_TOKEN + '\n' + prompt)
else:
prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt # noqa
conv = self.conv_templates[self.conv_mode].copy()
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
output_prompt = conv.get_prompt()
stop_str = conv.sep if conv.sep_style != self.SeparatorStyle.TWO else conv.sep2 # noqa
return output_prompt, stop_str
def _build_prompt(self, data_sample):
return self.reply_prompt
class LLaVAMMBenchPromptConstructor(LLaVABasePromptConstructor):
"""MMBench prompt constructor for LLaVA.
Args:
conv_mode (str): Conversation template version matching the LLaVA release.
mm_use_im_start_end (bool): Whether to wrap the image token with
    start and end tokens when building the prompt.
reply_prompt (str): Reply prompt added at the end. (Default: '')
"""
def __init__(self,
conv_mode: str,
mm_use_im_start_end: bool,
reply_prompt: str = '') -> None:
super().__init__(conv_mode, mm_use_im_start_end, reply_prompt)
def _build_prompt(self, data_sample):
question = data_sample.get('question')
options = data_sample.get('options')
context = data_sample.get('context')
if context is not None:
prompt = context + ' ' + question + ' ' + options
else:
prompt = question + ' ' + options
prompt += self.reply_prompt
return prompt
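# Editor's note (not part of the removed file): for a hypothetical sample with
# question='Which animal is shown?', options='A. cat B. dog' and no context,
# _build_prompt above returns 'Which animal is shown? A. cat B. dog' plus any
# reply_prompt; the base __call__ then prefixes '<image>\n' (or the
# <im_start>/<im_end>-wrapped image token) and inserts the result into the
# selected conversation template.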
class LLaVAVQAPromptConstructor(LLaVABasePromptConstructor):
"""VQA prompt constructor for LLaVA.
Args:
conv_mode (str): Conversation template version matching the LLaVA release.
mm_use_im_start_end (bool): Whether to wrap the image token with
    start and end tokens when building the prompt.
reply_prompt (str): Reply prompt added at the end. (Default: '')
"""
def __init__(self,
conv_mode: str,
mm_use_im_start_end: bool,
reply_prompt: str = '') -> None:
super().__init__(conv_mode, mm_use_im_start_end, reply_prompt)
def _build_prompt(self, data_sample):
prompt = data_sample.get('question')
prompt += self.reply_prompt
return prompt
class LLaVAScienceQAPromptConstructor(LLaVABasePromptConstructor):
"""ScienceQA prompt constructor for LLaVA.
Args:
conv_mode (str): Conversation template version matching the LLaVA release.
mm_use_im_start_end (bool): Whether to wrap the image token with
    start and end tokens when building the prompt.
reply_prompt (str): Reply prompt added at the end. (Default: '')
"""
choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
def __init__(self,
conv_mode: str,
mm_use_im_start_end: bool,
reply_prompt: str = '') -> None:
super().__init__(conv_mode, mm_use_im_start_end, reply_prompt)
def _build_prompt(self, data_sample):
question = data_sample.get('question')
choices = data_sample.get('choices')
choices = [
f'({self.choice_mapping[i]}) ' + item
for i, item in enumerate(choices)
]
choices = 'Choices: ' + ' '.join(choices) + '\n'
context = 'Context: ' + data_sample.get('hint') + '\n'
prompt = context + question + choices + self.reply_prompt
return prompt
from .minigpt_4 import MiniGPT4Inferencer
from .post_processor import (MiniGPT4COCOCaptionPostProcessor,
MiniGPT4MMBenchPostProcessor,
MiniGPT4MMEPostProcessor,
MiniGPT4ScienceQAPostProcessor,
MiniGPT4VQAPostProcessor,
MiniGPT4VSRPostProcessor)
from .prompt_constructor import MiniGPT4VSRPromptConstructor # noqa
from .prompt_constructor import (MiniGPT4COCOCaotionPromptConstructor,
MiniGPT4MMBenchPromptConstructor,
MiniGPT4MMEPromptConstructor,
MiniGPT4ScienceQAPromptConstructor,
MiniGPT4SEEDBenchPromptConstructor,
MiniGPT4VQAPromptConstructor)
__all__ = [
'MiniGPT4Inferencer', 'MiniGPT4MMBenchPostProcessor',
'MiniGPT4MMBenchPromptConstructor', 'MiniGPT4COCOCaotionPromptConstructor',
'MiniGPT4COCOCaptionPostProcessor', 'MiniGPT4ScienceQAPromptConstructor',
'MiniGPT4ScienceQAPostProcessor', 'MiniGPT4VQAPromptConstructor',
'MiniGPT4VQAPostProcessor', 'MiniGPT4VSRPostProcessor',
'MiniGPT4VSRPromptConstructor', 'MiniGPT4SEEDBenchPromptConstructor',
'MiniGPT4MMEPostProcessor', 'MiniGPT4MMEPromptConstructor'
]
import os
import sys
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import StoppingCriteriaList
from opencompass.registry import MM_MODELS
from .utils import StoppingCriteriaSub
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
def load_package():
"""Load required packages from MiniGPT-4."""
current_file_path = os.path.abspath(__file__)
current_folder_path = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_folder_path, 'MiniGPT-4')) # noqa
try:
# the latest version of MiniGPT4
from minigpt4.models.minigpt4 import MiniGPT4
except ImportError:
# the old version of MiniGPT4
from minigpt4.models.mini_gpt4 import MiniGPT4
sys.path.pop(-1)
return MiniGPT4
MiniGPT4 = load_package()
@MM_MODELS.register_module('minigpt-4')
class MiniGPT4Inferencer(MiniGPT4):
"""Inference code of MiniGPT-4.
Args:
llama_model (str): The path of the Vicuna weights.
prompt_constructor (dict): The config of prompt constructor.
post_processor (dict): The config of post processor.
do_sample (bool): Whether use sampling. Defaults to False.
max_length (int): The max length of output. Defaults to 30.
img_size (int): The size of image. Defaults to 224.
low_resource (bool): Whether to load the model in low precision.
    Defaults to False.
is_caption_task (bool): Whether the task is a caption task.
    Defaults to False.
mode (str): Inference mode, either 'generation' or 'loss'.
    Defaults to 'generation'.
n_segments (int): Number of chunks the answer choices are split into
    when computing per-choice losses in 'loss' mode. Defaults to 1.
"""
def __init__(self,
llama_model: str,
prompt_constructor: dict,
post_processor: dict,
do_sample: bool = False,
max_length: int = 30,
img_size: int = 224,
low_resource: bool = False,
is_caption_task: bool = False,
mode: str = 'generation',
n_segments: int = 1) -> None:
super().__init__(llama_model=llama_model,
low_resource=low_resource,
img_size=img_size)
self.mode = mode
self.n_segments = n_segments
cur_device = get_device()
stop_words_ids = [
torch.tensor([835]).to(cur_device),
torch.tensor([2277, 29937]).to(cur_device),
]
self.stopping_criteria = StoppingCriteriaList(
[StoppingCriteriaSub(stops=stop_words_ids)])
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
if post_processor is not None:
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
self.do_sample = do_sample
self.max_length = max_length
self.is_caption_task = is_caption_task
def forward(self, batch):
if self.mode == 'generation':
return self.generate(batch)
elif self.mode == 'loss':
return self.loss(batch)
else:
raise RuntimeError(f'Invalid mode "{self.mode}".')
def encode_img(self, image):
device = image.device
with self.maybe_autocast():
if image.dim() == 5:
inputs_llama, atts_llama = [], []
for j in range(image.size(2)):
this_frame = image[:, :, j, :, :]
frame_embeds = self.ln_vision(
self.visual_encoder(this_frame))
frame_atts = torch.ones(frame_embeds.size()[:-1],
dtype=torch.long).to(image.device)
query_tokens = self.query_tokens.expand(
frame_embeds.shape[0], -1, -1)
frame_query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=frame_embeds,
encoder_attention_mask=frame_atts,
return_dict=True,
)
frame_inputs_llama = self.llama_proj(
frame_query_output.last_hidden_state[:, :query_tokens.
size(1), :])
frame_atts_llama = torch.ones(
frame_inputs_llama.size()[:-1],
dtype=torch.long).to(image.device)
inputs_llama.append(frame_inputs_llama)
atts_llama.append(frame_atts_llama)
inputs_llama = torch.cat(inputs_llama, dim=1)
atts_llama = torch.cat(atts_llama, dim=1)
else:
image_embeds = self.ln_vision(
self.visual_encoder(image)).to(device)
image_atts = torch.ones(image_embeds.size()[:-1],
dtype=torch.long).to(device)
query_tokens = self.query_tokens.expand(
image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
inputs_llama = self.llama_proj(query_output.last_hidden_state)
atts_llama = torch.ones(inputs_llama.size()[:-1],
dtype=torch.long).to(image.device)
return inputs_llama, atts_llama
def pack_inputs(self, batch):
images = [image.unsqueeze(0) for image in batch['inputs']]
data_samples = [data_sample for data_sample in batch['data_samples']]
images = torch.cat(images, dim=0).to(get_device())
inputs = {'image': images, 'data_samples': data_samples}
return inputs
def generate(self, batch):
inputs = self.pack_inputs(batch)
inputs = self.prompt_constructor(inputs)
image = inputs['image']
prompt = inputs['prompt']
data_samples = inputs['data_samples']
# The main process of generation
img_embeds, _ = self.encode_img(image)
prompt_segs = prompt.split('<ImageHere>')
prompt_seg_tokens = [
self.llama_tokenizer(seg,
return_tensors='pt',
add_special_tokens=i == 0).
to(self.llama_model.model.embed_tokens.weight.device).input_ids
for i, seg in enumerate(prompt_segs)
]
prompt_seg_embs = [
self.llama_model.model.embed_tokens(seg)
for seg in prompt_seg_tokens
]
prompt_seg_embs = [prompt_seg_embs[0], img_embeds, prompt_seg_embs[1]]
prompt_embs = torch.cat(prompt_seg_embs, dim=1)
# generate output
outputs = self.llama_model.generate(
inputs_embeds=prompt_embs,
max_length=self.max_length,
num_beams=5,
do_sample=self.do_sample,
min_length=1,
top_p=0.9,
repetition_penalty=1.0,
length_penalty=-1.0,
temperature=1.0,
stopping_criteria=self.stopping_criteria,
num_return_sequences=1)
for i, data_sample in enumerate(data_samples):
output_token = outputs[i]
output_text = self.post_processor(output_token,
self.llama_tokenizer)
if self.is_caption_task:
data_sample.pred_caption = output_text
else:
data_sample.pred_answer = output_text
data_samples[i] = data_sample
return data_samples
def loss(self, batch):
inputs = self.pack_inputs(batch)
inputs = self.prompt_constructor(inputs)
image = inputs['image']
batch_size = image.size(0)
prompt = inputs['prompt']
data_samples = inputs['data_samples']
choices = data_samples[0].choices
with torch.no_grad():
img_embeds, atts_img = self.encode_img(image)
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
prompt)
self.llama_tokenizer.padding_side = 'right'
n_cands = len(choices)
losses = []
for n in range(self.n_segments):
seg_len = n_cands // self.n_segments
if n == (self.n_segments - 1):
seg_len = n_cands - seg_len * (self.n_segments - 1)
to_regress_tokens = self.llama_tokenizer(
choices,
return_tensors='pt',
padding='longest',
truncation=True,
max_length=self.max_txt_len,
add_special_tokens=False).to(image.device)
targets = to_regress_tokens.input_ids.masked_fill(
to_regress_tokens.input_ids ==
self.llama_tokenizer.pad_token_id, -100)
empty_targets = (
torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
dtype=torch.long).to(image.device).fill_(
-100) # plus one for bos
)
empty_targets = empty_targets.repeat_interleave(seg_len, dim=0)
targets = torch.cat([empty_targets, targets], dim=1)
bos = torch.ones([batch_size, 1],
dtype=to_regress_tokens.input_ids.dtype,
device=to_regress_tokens.input_ids.device
) * self.llama_tokenizer.bos_token_id
bos_embeds = self.llama_model.model.embed_tokens(bos)
bos_embeds = bos_embeds.repeat_interleave(seg_len, dim=0)
img_embeds = img_embeds.repeat_interleave(seg_len, dim=0)
atts_bos = atts_img[:, :1]
atts_bos = atts_bos.repeat_interleave(seg_len, dim=0)
atts_img = atts_img.repeat_interleave(seg_len, dim=0)
to_regress_embeds = self.llama_model.model.embed_tokens(
to_regress_tokens.input_ids)
inputs_embeds = torch.cat(
[bos_embeds, img_embeds, to_regress_embeds], dim=1)
attention_mask = torch.cat(
[atts_bos, atts_img, to_regress_tokens.attention_mask],
dim=1)
with self.maybe_autocast():
outputs = self.llama_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
return_dict=True,
labels=targets,
reduction='none',
)
loss = outputs.loss
loss = loss.view(targets.size(0), -1).sum(1)
loss = loss.reshape(batch_size, seg_len)
losses.append(loss)
# losses of 4 choices
losses = torch.cat(losses, dim=-1)[0]
for i, data_sample in enumerate(data_samples):
data_sample.losses = losses
data_samples[i] = data_sample
return data_samples
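# Editor's note (not part of the removed file): the inferencer above supports
# two evaluation styles selected by `mode`. 'generation' decodes a free-form
# answer that the post processor then parses; 'loss' scores every entry of
# data_sample.choices by its summed language-model loss (computed in
# `n_segments` chunks to bound memory), e.g. choices = ['A', 'B', 'C', 'D']
# yields a 4-element `losses` tensor on the data sample, from which a
# downstream evaluator can pick the lowest-loss option.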
import random
import re
import torch
class MiniGPT4MMBenchPostProcessor:
""""Post processor for MiniGPT-4 on MMBench."""
def __init__(self) -> None:
pass
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
if output_token[0] == 0:
output_token = output_token[1:]
if output_token[0] == 1:
output_token = output_token[1:]
output_text = tokenizer.decode(output_token,
add_special_tokens=False) # noqa
output_text = self._extract_key_words(output_text)
return output_text
def _extract_key_words(self, output_text: str) -> str:
output_text = output_text.split('###')[0]
output_text = output_text.split('Assistant:')[-1].strip()
output_text = output_text.strip('</s><s>')
output_text = output_text.strip('</Img>')
output_text = output_text.strip()
pattern = re.compile(r'([A-Z]\.)')
res = pattern.findall(output_text)
if len(res) > 0:
output_text = res[0][:-1]
return output_text
class MiniGPT4COCOCaptionPostProcessor:
""""Post processor for MiniGPT-4 on COCO Caption."""
def __init__(self) -> None:
pass
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
if output_token[0] == 0:
output_token = output_token[1:]
if output_token[0] == 1:
output_token = output_token[1:]
output_text = tokenizer.decode(output_token,
add_special_tokens=False) # noqa
output_text = output_text.split('###')[0]
output_text = output_text.split('Assistant:')[-1].strip()
output_text = output_text.split('. ')[0]
output_text = output_text.strip('<Img>')
output_text = output_text.strip()
return output_text
class MiniGPT4ScienceQAPostProcessor:
""""Post processor for MiniGPT-4 on ScienceQA."""
def __init__(self) -> None:
pass
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
if output_token[0] == 0:
output_token = output_token[1:]
if output_token[0] == 1:
output_token = output_token[1:]
output_text = tokenizer.decode(output_token,
add_special_tokens=False) # noqa
output_text = output_text.split('###')[0]
output_text = output_text.split('Assistant:')[-1].strip()
pattern = re.compile(r'\(([A-Z])\)')
output_text = pattern.findall(output_text)
if len(output_text) == 0:
output_text = random.choice(['A', 'B', 'C', 'D'])
else:
output_text = output_text[0]
return output_text
class MiniGPT4VQAPostProcessor:
""""Post processor for MiniGPT-4 on VQA."""
def __init__(self) -> None:
pass
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
if output_token[0] == 0:
output_token = output_token[1:]
if output_token[0] == 1:
output_token = output_token[1:]
output_text = tokenizer.decode(output_token,
add_special_tokens=False) # noqa
output_text = output_text.split('###')[0]
output_text = output_text.split('Assistant:')[-1].strip()
return output_text
class MiniGPT4VSRPostProcessor:
""""Post processor for MiniGPT-4 on VSR."""
def __init__(self) -> None:
pass
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
if output_token[0] == 0:
output_token = output_token[1:]
if output_token[0] == 1:
output_token = output_token[1:]
output_text = tokenizer.decode(output_token, add_special_tokens=False)
pattern = r'yes|no|Yes|No'
output_text = re.findall(pattern, output_text)
if len(output_text) > 0:
output_text = output_text[0].lower()
return output_text
class MiniGPT4MMEPostProcessor(MiniGPT4MMBenchPostProcessor):
""""Post processor for MiniGPT-4 on MME."""
def __init__(self) -> None:
super().__init__()
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
response = super().__call__(output_token, tokenizer)
# extract yes or no, copy from MME official evaluation script
prefix_pred_ans = response[:4].lower()
if 'yes' in prefix_pred_ans:
pred_label = 'yes'
elif 'no' in prefix_pred_ans:
pred_label = 'no'
else:
pred_label = 'other'
return pred_label
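# Editor's sketch (not part of the removed file): a quick check of the
# option-letter extraction above; the decoded string is hypothetical and the
# tokenizer-dependent decoding step is skipped.
if __name__ == '__main__':
    pp = MiniGPT4MMBenchPostProcessor()
    assert pp._extract_key_words('Assistant: B. The cat sits on the mat.###') == 'B'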
from typing import List
from mmpretrain.structures import DataSample
class MiniGPT4MMBenchPromptConstructor:
"""Prompt constructor for MiniGPT-4 on MMBench.
Args:
image_prompt (str): Image prompt. Defaults to `''`.
reply_prompt (str): Reply prompt. Defaults to `''`.
"""
def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
self.image_prompt = image_prompt
self.reply_prompt = reply_prompt
def __call__(self, inputs: dict) -> dict:
"""Construct prompt.
Args:
inputs (dict): Input data containing image and data_samples.
Returns:
dict: A dict containing prompt, images and data_samples.
"""
data_samples = inputs['data_samples']
prompt = self._process(data_samples)
inputs.update({'prompt': prompt})
return inputs
def _process(self, data_samples: List[DataSample]) -> str:
"""Process data sample to prompt.
Args:
data_samples (List[DataSample]): A list of data_samples.
Returns:
str: Prompt.
"""
assert len(data_samples) == 1, 'Only support batch size 1.'
questions = [
data_sample.get('question') for data_sample in data_samples
]
options = [data_sample.get('options') for data_sample in data_samples]
contexts = [data_sample.get('context') for data_sample in data_samples]
question = questions[0]
option = options[0]
context = contexts[0]
if context is not None:
prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
else:
prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
return prompt
class MiniGPT4COCOCaotionPromptConstructor(MiniGPT4MMBenchPromptConstructor):
"""Prompt constructor for MiniGPT-4 on COCO Caption."""
def _process(self, data_samples: List[DataSample]) -> str:
assert len(data_samples) == 1, 'Only support batch size 1.'
prompt = self.image_prompt + ' ' + 'a photo of' + self.reply_prompt
return prompt
class MiniGPT4ScienceQAPromptConstructor(MiniGPT4MMBenchPromptConstructor):
"""Prompt constructor for MiniGPT-4 on ScienceQA."""
choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
def _process(self, data_samples: List[DataSample]) -> str:
assert len(data_samples) == 1, 'Only support batch size 1.'
questions = [
'Question: ' + data_sample.get('question') + '\n'
for data_sample in data_samples
] # noqa
choices = [data_sample.get('choices') for data_sample in data_samples]
choices = [[
f'({self.choice_mapping[i]}) ' + item
for i, item in enumerate(choice)
] for choice in choices]
choices = [
'Choices: ' + ' '.join(choice) + '\n' for choice in choices
] # noqa
contexts = [
'Context: ' + data_sample.get('hint') + '\n'
for data_sample in data_samples
] # noqa
question = questions[0]
choice = choices[0]
context = contexts[0]
prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + choice + self.reply_prompt + ' ' + 'The answer is' # noqa
return prompt
class MiniGPT4VQAPromptConstructor(MiniGPT4MMBenchPromptConstructor):
"""Prompt constructor for MiniGPT-4 on VQA."""
def _process(self, data_samples: List[DataSample]) -> str:
assert len(data_samples) == 1, 'Only support batch size 1.'
questions = [
data_sample.get('question') for data_sample in data_samples
]
question = questions[0]
prompt = self.image_prompt + ' ' + question + ' ' + 'Answer this question in a single word.' + ' ' + self.reply_prompt # noqa
return prompt
class MiniGPT4VSRPromptConstructor(MiniGPT4MMBenchPromptConstructor):
"""Prompt constructor for MiniGPT-4 on VSR."""
def _process(self, data_samples: List[DataSample]) -> str:
assert len(data_samples) == 1, 'Only support batch size 1.'
questions = [
data_sample.get('question') for data_sample in data_samples
]
question = questions[0]
prompt = self.image_prompt + ' ' + question + ' ' + 'Is the above description correct? Answer yes or no.' + ' ' + self.reply_prompt # noqa
return prompt
class MiniGPT4SEEDBenchPromptConstructor(MiniGPT4MMBenchPromptConstructor):
def _process(self, data_samples: List[DataSample]) -> str:
"""Process data sample to prompt.
Args:
data_samples (List[DataSample]): A list of data_samples.
Returns:
str: Prompt.
"""
assert len(data_samples) == 1, 'Only support batch size 1.'
questions = [
data_sample.get('question') for data_sample in data_samples
]
question = questions[0]
prompt = self.image_prompt + ' ' + question + ' ' + self.reply_prompt
return prompt
class MiniGPT4MMEPromptConstructor:
"""Prompt constructor for MiniGPT-4 on MME.
Args:
image_prompt (str): Image prompt. Defaults to `''`.
reply_prompt (str): Reply prompt. Defaults to `''`.
"""
def __init__(self) -> None:
self.system_prompt = (
'Give the following image: <Img>ImageContent</Img>.'
'You will be able to see the image once I provide it to you.'
'Please answer my questions.')
self.sep = '###'
def __call__(self, inputs: dict) -> dict:
"""Construct prompt.
Args:
inputs (dict): Input data containing image and data_samples.
Returns:
dict: A dict containing prompt, images and data_samples.
"""
data_samples = inputs['data_samples']
prompt = self._process(data_samples)
inputs.update({'prompt': prompt})
return inputs
def _process(self, data_samples: List[DataSample]) -> str:
"""Process data sample to prompt.
Args:
data_samples (List[DataSample]): A list of data_samples.
Returns:
str: Prompt.
"""
assert len(data_samples) == 1, 'Only support batch size 1.'
question = data_samples[0].get('question')
prompt = self.system_prompt + self.sep
prompt += 'Human: ' + question + ' ' + '<Img><ImageHere></Img>' + ' ' + self.sep # noqa
prompt += 'Assistant: '
return prompt
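# Editor's note (not part of the removed file): for a hypothetical question
# 'Is there a dog in the image?' the constructor above produces the single
# string
#   'Give the following image: <Img>ImageContent</Img>.You will be able to
#    see the image once I provide it to you.Please answer my questions.###'
#   'Human: Is there a dog in the image? <Img><ImageHere></Img> ###Assistant: '
# (wrapped here for readability): system prompt, '###' separator, a Human turn
# containing the <ImageHere> placeholder, then an open Assistant turn.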
import os
import re
import timm.models.hub as timm_hub
import torch
import torch.distributed as dist
from mmengine.dist import is_distributed, is_main_process
from transformers import StoppingCriteria
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self, stops=[], encounters=1):
super().__init__()
self.stops = stops
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
for stop in self.stops:
if torch.all((stop == input_ids[0][-len(stop):])).item():
return True
return False
def download_cached_file(url, check_hash=True, progress=False):
"""Download a file from a URL and cache it locally.
If the file already exists, it is not downloaded again. If distributed,
only the main process downloads the file, and the other processes wait for
the file to be downloaded.
"""
def get_cached_file_path():
# a hack to sync the file path across processes
parts = torch.hub.urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
return cached_file
if is_main_process():
timm_hub.download_cached_file(url, check_hash, progress)
if is_distributed():
dist.barrier()
return get_cached_file_path()
def is_url(input_url):
"""Check if an input string is a url.
look for http(s):// and ignoring the case
"""
is_url = re.match(r'^(?:http)s?://', input_url, re.IGNORECASE) is not None
return is_url
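# Editor's sketch (not part of the removed file): a quick check of the URL
# helper above.
if __name__ == '__main__':
    assert is_url('https://example.com/pretrained.pth')
    assert not is_url('ckpts/pretrained.pth')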
from .mplug_owl_7b import MplugOwl
from .post_processor import MplugOwlMMBenchPostProcessor
from .prompt_constructor import MplugOwlMMBenchPromptConstructor # noqa
__all__ = [
'MplugOwl', 'MplugOwlMMBenchPostProcessor',
'MplugOwlMMBenchPromptConstructor'
]
import os
import sys
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from opencompass.registry import MM_MODELS
def load_package():
"""Load required packages from llama_adapter_v2_multimodal7b."""
current_file_path = os.path.abspath(__file__)
current_folder_path = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_folder_path, 'mPLUG-Owl')) # noqa
from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl.processing_mplug_owl import (MplugOwlImageProcessor,
MplugOwlProcessor)
from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
sys.path.pop(-1)
return MplugOwlForConditionalGeneration, MplugOwlImageProcessor, MplugOwlProcessor, MplugOwlTokenizer # noqa
MplugOwlForConditionalGeneration, MplugOwlImageProcessor, MplugOwlProcessor, MplugOwlTokenizer = load_package( # noqa
) # noqa
@MM_MODELS.register_module('mplug_owl_7b')
class MplugOwl(nn.Module):
def __init__(self,
prompt_constructor: dict,
post_processor: dict,
model_path='MAGAer13/mplug-owl-llama-7b',
mode: str = 'generation'):
super().__init__()
pretrained_ckpt = model_path
print(pretrained_ckpt)
self.model = MplugOwlForConditionalGeneration.from_pretrained(
pretrained_ckpt,
torch_dtype=torch.bfloat16,
).cuda()
self.image_processor = MplugOwlImageProcessor.from_pretrained(
pretrained_ckpt)
self.tokenizer = MplugOwlTokenizer.from_pretrained(pretrained_ckpt)
self.processor = MplugOwlProcessor(self.image_processor,
self.tokenizer)
self.generate_kwargs = {
'do_sample': False,
'top_k': 5,
'max_length': 20,
'num_beams': 3,
}
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
if post_processor is not None:
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
self.mode = mode
def forward(self, batch):
if self.mode == 'generation':
return self.generate(batch)
def generate(self, batch):
images = [image.unsqueeze(0) for image in batch['inputs']]
data_samples = [data_sample for data_sample in batch['data_samples']]
images = torch.cat(images, dim=0).to(get_device())
inputs = {'image': images, 'data_samples': data_samples}
inputs = self.prompt_constructor(inputs)
image = inputs['image']
prompt = inputs['prompt'][0]
data_samples = inputs['data_samples']
data_sample = data_samples[0]
owl_template = """The following is a conversation
between a curious human and AI assistant.
The assistant gives helpful, detailed, and
polite answers to the user's questions.
Human: <image>
Human: {text_input}
AI: """
prompt = owl_template.format(text_input=prompt)
inputs = self.processor(text=[prompt], return_tensors='pt')
inputs['pixel_values'] = image
inputs = {
k: v.bfloat16() if v.dtype == torch.float else v
for k, v in inputs.items()
}
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
with torch.no_grad():
res = self.model.generate(**inputs, **self.generate_kwargs)
output_text = self.tokenizer.decode(res.tolist()[0],
skip_special_tokens=True)
output_text = self.post_processor(output_text)
data_sample.pred_answer = output_text
return data_sample
import re
import torch
class MplugOwlMMBenchPostProcessor:
""""Post processor for MplugOwl on MMBench."""
def __init__(self) -> None:
pass
def __call__(self, output_token: str) -> str:
pattern = re.compile(r'([A-Z]\.)')
res = pattern.findall(output_token)
if len(res) > 0:
output_token = res[0][:-1]
return output_token
from typing import List
from mmpretrain.structures import DataSample
class MplugOwlMMBenchPromptConstructor:
"""Prompt constructor for MplugOwl on MMBench.
Args:
image_prompt (str): Image prompt. Defaults to `''`.
reply_prompt (str): Reply prompt. Defaults to `''`.
"""
def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
self.image_prompt = image_prompt
self.reply_prompt = reply_prompt
def __call__(self, inputs: dict) -> dict:
"""Construct prompt.
Args:
inputs (dict): Input data containing image and data_samples.
Returns:
dict: A dict containing prompt, images and data_samples.
"""
data_samples = inputs['data_samples']
prompt = self._process(data_samples)
inputs.update({'prompt': prompt})
return inputs
def _process(self, data_samples: List[DataSample]) -> List[str]:
"""Process data sample to prompt.
Args:
data_samples (List[DataSample]): A list of data_samples.
Returns:
List[str]: A list of prompts, one per data sample.
"""
question = [
data_sample.get('question') for data_sample in data_samples
]
options = [data_sample.get('options') for data_sample in data_samples]
if data_samples[0].get('context') is not None:
context = [
data_sample.get('context') for data_sample in data_samples
]
else:
context = [''] * len(data_samples)
prompts = []
for cur_context, cur_question, cur_options in zip(
context, question, options):
prompts.append(cur_context + ' ' + cur_question + ' ' +
cur_options) # noqa
return prompts
from .openflamingo import OpenFlamingoInferencer
from .post_processor import OpenFlamingoVSRPostProcessor
from .prompt_constructor import (OpenFlamingoCaptionPromptConstructor,
OpenFlamingoMMBenchPromptConstructor,
OpenFlamingoScienceQAPromptConstructor,
OpenFlamingoVQAPromptConstructor)
__all__ = [
'OpenFlamingoInferencer', 'OpenFlamingoMMBenchPromptConstructor',
'OpenFlamingoCaptionPromptConstructor', 'OpenFlamingoVQAPromptConstructor',
'OpenFlamingoScienceQAPromptConstructor', 'OpenFlamingoVSRPostProcessor'
]
import re
from typing import List, Optional, Union
import mmengine
import torch
from mmpretrain.models.multimodal import Flamingo
from mmpretrain.structures import DataSample
from opencompass.registry import MM_MODELS
@MM_MODELS.register_module('openflamingo')
class OpenFlamingoInferencer(Flamingo):
"""Inference code of OpenFlamingo.
Args:
prompt_constructor (optional, dict): The config of prompt constructor.
Defaults to None.
post_processor (optional, dict): The config of post processor.
Defaults to None.
mode (str): The mode of inference. Defaults to 'generation'.
"""
def __init__(self,
prompt_constructor: dict,
post_processor: Optional[dict] = None,
mode: str = 'generation',
**kwargs):
super().__init__(**kwargs)
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
if post_processor is not None:
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
else:
self.post_processor = None
self.mode = mode
def preprocess_text(self, data_samples: List[DataSample],
device: torch.device) -> List[DataSample]:
"""Preprocess text in advance before fed into language model.
Args:
data_samples (List[DataSample]): The annotation
data of each sample. Defaults to None.
device (torch.device): Device for text to put on.
Returns:
List[DataSample]: Return list of data samples.
"""
prompts = self.prompt_constructor(data_samples)
self.tokenizer.padding_side = 'left'
input_text = self.tokenizer(
prompts,
padding='longest',
truncation=True,
return_tensors='pt',
max_length=2000,
).to(device)
return input_text
def post_process(
self, outputs: torch.Tensor,
data_samples: Optional[List[DataSample]]) -> List[DataSample]:
"""Perform post process for outputs for different task.
Args:
outputs (torch.Tensor): The generated outputs.
data_samples (List[DataSample], optional): The annotation
data of each sample.
Returns:
List[DataSample]: Return list of data samples.
"""
outputs = self.tokenizer.batch_decode(outputs,
skip_special_tokens=True)
if data_samples is None:
data_samples = [DataSample() for _ in range(len(outputs))]
for output, data_sample in zip(outputs, data_samples):
# remove text pattern
if self.task == 'caption':
data_sample.pred_caption = re.split('Output', output,
1)[0].replace('"', '')
if self.post_processor:
data_sample.pred_caption = self.post_processor(
data_sample.pred_caption)
elif self.task == 'vqa':
data_sample.pred_answer = re.split('Question|Answer', output,
1)[0]
if self.post_processor:
data_sample.pred_answer = self.post_processor(
data_sample.pred_answer)
return data_samples
def forward(self, batch: dict) -> Union[DataSample, List[DataSample]]:
if self.mode == 'generation':
return self.generate(batch)
else:
raise RuntimeError(f'Unsupported mode: {self.mode}')
def generate(self, batch: dict) -> Union[DataSample, List[DataSample]]:
batch = self.data_preprocessor(batch, False)
images = batch['images']
data_samples = batch['data_samples']
return self.predict(images, data_samples)
class OpenFlamingoVSRPostProcessor:
"""VSR post processor for Openflamingo."""
def __init__(self) -> None:
pass
def __call__(self, raw_response: str) -> str:
if 'yes' in raw_response.lower():
return 'yes'
elif 'no' in raw_response.lower():
return 'no'
else:
return 'unknown'
from typing import Optional
from mmpretrain.structures import DataSample
class OpenFlamingoMMBenchPromptConstructor:
"""MMBench prompt constructor for OpenFlamingo."""
def __init__(self) -> None:
pass
def __call__(self, data_samples: DataSample) -> list:
"""Construct prompt.
Args:
data_samples (DataSample): Input data_samples.
Returns:
list: Raw text inputs.
"""
assert len(data_samples) == 1
sample = data_samples[0]
prompts = []
question = sample.get('question')
option = sample.get('options')
prompt = '<image>' + question + ' ' + option + ' ' + 'Answer:'
if sample.get('context') is not None:
prompt = sample.get('context') + ' ' + prompt
prompts.append(prompt)
return prompts
class OpenFlamingoCaptionPromptConstructor:
"""Caption prompt constructor for OpenFlamingo."""
def __init__(self, shot_prompt: Optional[str] = None) -> None:
if shot_prompt:
self.shot_prompt = shot_prompt
else:
self.shot_prompt = (
'Output:A child holding a flowered umbrella and petting a yak.<|endofchunk|>' # noqa
'Output:The child is holding a brush close to his mouth.<|endofchunk|>' # noqa
) # noqa
def __call__(self, data_samples: DataSample) -> list:
"""Construct prompt.
Args:
data_samples (DataSample): Input data_samples.
Returns:
list: Raw text inputs.
"""
assert len(data_samples) == 1
prompts = []
prompt = '<image>Output:'
prompts.append(self.shot_prompt + prompt)
return prompts
class OpenFlamingoVQAPromptConstructor:
"""VQA prompt constructor for OpenFlamingo."""
def __init__(self, shot_prompt: Optional[str] = None) -> None:
if shot_prompt:
self.shot_prompt = shot_prompt
else:
self.shot_prompt = (
'Question:Is the sky dark? Short Answer:yes<|endofchunk|>' # noqa: E501
'Question:What is on the white wall? Short Answer:pipe<|endofchunk|>' # noqa: E501
) # noqa
def __call__(self, data_samples: DataSample) -> list:
"""Construct prompt.
Args:
data_samples (DataSample): Input data_samples.
Returns:
list: Raw text inputs.
"""
prompts = []
for sample in data_samples:
question = sample.get('question')
prompt = '<image>Question:{} Short Answer:'.format(question)
prompts.append(self.shot_prompt + prompt)
return prompts
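# Editor's note (not part of the removed file): with the default shot prompt
# above, a hypothetical question 'What color is the bus?' becomes the single
# string
#   'Question:Is the sky dark? Short Answer:yes<|endofchunk|>'
#   'Question:What is on the white wall? Short Answer:pipe<|endofchunk|>'
#   '<image>Question:What color is the bus? Short Answer:'
# (wrapped here for readability): two text-only in-context shots followed by
# the image placeholder and the open answer slot.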
class OpenFlamingoScienceQAPromptConstructor:
"""ScienceQA prompt constructor for OpenFlamingo."""
choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
def __init__(self, shot_prompt: Optional[str] = None) -> None:
if shot_prompt:
self.shot_prompt = shot_prompt
else:
self.shot_prompt = (
"Context:Question:Which of these states is farthest north? Choices:['(A) West Virginia' '(B) Louisiana' '(C) Arizona' '(D) Oklahoma'] Answer with a single character: A<|endofchunk|>" # noqa
'Context:The diagrams below show two pure samples of gas in identical closed, rigid containers. Each colored ball represents one gas particle. Both samples have the same number of particles.' # noqa
"Question:Compare the average kinetic energies of the particles in each sample. Which sample has the higher temperature? Choices:'[(A) neither' '(B) sample A' '(C) sample B'] Answer with a single character: C<|endofchunk|>" # noqa
) # noqa
def __call__(self, data_samples: DataSample) -> list:
"""Construct prompt.
Args:
data_samples (DataSample): Input data_samples.
Returns:
list: Raw text inputs.
"""
assert len(data_samples) == 1
sample = data_samples[0]
question = sample.get('question')
choices = sample.get('choices')
choices = [
f'({self.choice_mapping[i]}) ' + item
for i, item in enumerate(choices)
]
hint = sample.get('hint')
prompts = []
prompt = '<image>Context:{} Question:{} Choices:{}'.format(
hint, question, choices)
prompt += ' Answer with a single character:'
prompts.append(self.shot_prompt + prompt)
return prompts