Unverified Commit 1034c487 authored by Yixiao Fang, committed by GitHub

[Refactor] Refactor instructblip (#227)

* refactor instructblip

* add post processor

* add forward

* fix lint

* update

* update
parent 02ce139b
@@ -6,4 +6,44 @@
git clone https://github.com/salesforce/LAVIS.git
cd ./LAVIS
pip install -e .
```

### Modify the config
Modify the config of InstructBLIP, such as the model paths of the LLM and the Q-Former.
Then update `tasks.py` as in the following code snippet.
```python
from mmengine.config import read_base

with read_base():
    from .instructblip.instructblip_mmbench import (instruct_blip_dataloader,
                                                    instruct_blip_evaluator,
                                                    instruct_blip_load_from,
                                                    instruct_blip_model)

models = [instruct_blip_model]
datasets = [instruct_blip_dataloader]
evaluators = [instruct_blip_evaluator]
load_froms = [instruct_blip_load_from]

num_gpus = 8
num_procs = 8
launcher = 'pytorch'  # or 'slurm'
```
### Start evaluation
#### Slurm
```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
```
#### PyTorch
```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval
```
configs/multimodal/instructblip/instructblip_mmbench.py

from opencompass.multimodal.models.instructblip import (
    InstructBlipMMBenchPromptConstructor, InstructBlipMMBenchPostProcessor)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.torchvision/Resize',
@@ -9,24 +12,27 @@ val_pipeline = [
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'category', 'l2-category', 'context', 'index',
             'options_dict', 'options', 'split'
         ])
]

dataset = dict(type='opencompass.MMBenchDataset',
               data_file='data/mmbench/mmbench_test_20230712.tsv',
               pipeline=val_pipeline)

instruct_blip_dataloader = dict(batch_size=1,
                                num_workers=4,
                                dataset=dataset,
                                collate_fn=dict(type='pseudo_collate'),
                                sampler=dict(type='DefaultSampler',
                                             shuffle=False))

# model settings
instruct_blip_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipMMBenchPromptConstructor),
    post_processor=dict(type=InstructBlipMMBenchPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
@@ -35,11 +41,11 @@ model = dict(
)

# evaluation settings
instruct_blip_evaluator = [
    dict(
        type='opencompass.DumpResults',
        save_path=  # noqa: E251
        'work_dirs/instructblip_vicuna7b/instructblipvicuna_mmbench.xlsx')
]

instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'  # noqa
opencompass/multimodal/models/instructblip/__init__.py

from .blip2_vicuna_instruct import InstructBlipInferencer
from .post_processor import InstructBlipMMBenchPostProcessor
from .prompt_constructor import InstructBlipMMBenchPromptConstructor

__all__ = [
    'InstructBlipInferencer', 'InstructBlipMMBenchPromptConstructor',
    'InstructBlipMMBenchPostProcessor'
]
"""Requires Transformer 4.28 and above, implementation may change according the """Requires Transformer 4.28 and above, implementation may change according the
Llama implementation.""" Llama implementation."""
import logging import logging
import re
import mmengine
import torch import torch
import torch.nn as nn import torch.nn as nn
from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
@@ -12,27 +12,36 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
from opencompass.registry import MM_MODELS


@MM_MODELS.register_module('blip2-vicuna-instruct')
class InstructBlipInferencer(Blip2Base):

    def __init__(
        self,
        prompt_constructor: dict,
        post_processor: dict,
        vit_model: str = 'eva_clip_g',
        img_size: int = 224,
        drop_path_rate: float = 0,
        use_grad_checkpoint: bool = False,
        vit_precision: str = 'fp16',
        freeze_vit: bool = True,
        num_query_token: int = 32,
        llm_model: str = '',
        sys_prompt: str = '',
        prompt: str = '',
        max_txt_len: int = 128,
        max_output_txt_len: int = 256,
        qformer_text_input: bool = True,
        low_resource: bool = False,
        mode: str = 'generation',
    ):
        super().__init__()

        self.mode = mode
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        self.post_processor = mmengine.registry.build_from_cfg(
            post_processor, MM_MODELS)

        self.tokenizer = self.init_tokenizer(truncation_side='left')

        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
@@ -92,6 +101,12 @@ class Blip2VicunaInstructMMBench(Blip2Base):
        self.qformer_text_input = qformer_text_input

    def forward(self, batch):
        if self.mode == 'generation':
            return self.generate(batch)
        else:
            raise RuntimeError(f'Invalid mode "{self.mode}".')

    def concat_text_input_output(self, input_ids, input_atts, output_ids,
                                 output_atts):
        input_part_targets_len = []
@@ -136,31 +151,13 @@ class Blip2VicunaInstructMMBench(Blip2Base):
        temperature=1,
    ):
        inputs = self.pack_inputs(batch)
        inputs = self.prompt_constructor(inputs)
        image = inputs['image']
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']

        self.llm_tokenizer.padding_side = 'left'

        bs = image.size(0)

        if isinstance(prompt, str):
@@ -237,24 +234,10 @@ class Blip2VicunaInstructMMBench(Blip2Base):
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )

        for i, data_sample in enumerate(data_samples):
            output_token = outputs[i]
            output_text = self.post_processor(output_token, self.llm_tokenizer)
            data_sample.pred_answer = output_text
            data_samples[i] = data_sample
        return data_samples
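To make the change at the end of `generate()` concrete: it now post-processes every sample in the batch and returns the full list of `DataSample`s with `pred_answer` set, instead of decoding only the first sample. Below is a minimal smoke-test sketch; the wiring (config loading, registry build, dataloader construction) is assumed for illustration and is not part of this diff.

```python
from mmengine.config import Config
from mmengine.runner import Runner

from opencompass.registry import MM_MODELS

# Assumed wiring, not code from this PR: load the config shown above, build the
# inferencer through the MM_MODELS registry (the string 'blip2-vicuna-instruct'
# resolves to InstructBlipInferencer, and the nested prompt_constructor /
# post_processor dicts are built inside its __init__), then run a single batch.
cfg = Config.fromfile('configs/multimodal/instructblip/instructblip_mmbench.py')
model = MM_MODELS.build(cfg.instruct_blip_model)
dataloader = Runner.build_dataloader(cfg.instruct_blip_dataloader)

for batch in dataloader:
    data_samples = model(batch)  # forward() dispatches to generate()
    for data_sample in data_samples:
        print(data_sample.get('index'), data_sample.get('pred_answer'))
    break  # one batch is enough to check the plumbing
```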
opencompass/multimodal/models/instructblip/post_processor.py

import re

import torch


class InstructBlipMMBenchPostProcessor:
    """Post processor for InstructBlip on MMBench."""

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
        # convert output id 0 to 2 (eos_token_id)
        output_token[output_token == 0] = 2
        output_text = tokenizer.decode(output_token,
                                       add_special_tokens=False)  # noqa
        output_text = self._extract_key_words(output_text.strip())
        return output_text

    def _extract_key_words(self, output_text: str) -> str:
        # Drop chat scaffolding and special tokens, then keep only the first
        # option letter (e.g. 'A') if one is present.
        output_text = output_text.split('###')[0]
        output_text = output_text.split('Assistant:')[-1].strip()
        output_text = output_text.strip('</s><s>')
        output_text = output_text.strip('</Img>')
        output_text = output_text.strip()
        pattern = re.compile(r'([A-Z]\.)')
        res = pattern.findall(output_text)
        if len(res) > 0:
            output_text = res[0][:-1]
        return output_text
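For a quick sense of what the post processor keeps, its `_extract_key_words` helper can be exercised directly on a raw decoded string, bypassing the tokenizer step of `__call__`. The example string below is invented for illustration.

```python
from opencompass.multimodal.models.instructblip import (
    InstructBlipMMBenchPostProcessor)

post_processor = InstructBlipMMBenchPostProcessor()

# The chat scaffolding is stripped and only the leading option letter survives.
raw = 'Assistant: B. a dog sitting on a bench###'
print(post_processor._extract_key_words(raw))  # -> 'B'
```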
opencompass/multimodal/models/instructblip/prompt_constructor.py

from typing import List

from mmpretrain.structures import DataSample


class InstructBlipMMBenchPromptConstructor:
    """Prompt constructor for InstructBlip on MMBench.

    Args:
        image_prompt (str): Image prompt.
        reply_prompt (str): Reply prompt.
    """

    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
        self.image_prompt = image_prompt
        self.reply_prompt = reply_prompt

    def __call__(self, inputs: dict) -> dict:
        """Construct prompt.

        Args:
            inputs (dict): Input data containing image and data_samples.

        Returns:
            dict: A dict containing prompt, images and data_samples.
        """
        data_samples = inputs['data_samples']
        prompt = self._process(data_samples)
        inputs.update({'prompt': prompt})

        return inputs

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        questions = [
            data_sample.get('question') for data_sample in data_samples
        ]
        options = [data_sample.get('options') for data_sample in data_samples]
        contexts = [data_sample.get('context') for data_sample in data_samples]
        question = questions[0]
        option = options[0]
        context = contexts[0]
        if context is not None:
            prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
        else:
            prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
        return prompt
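A matching sketch for the prompt constructor, using a hand-built `DataSample`. The question, options, and `reply_prompt` values here are invented for illustration; the MMBench config above keeps both prompts at their `''` defaults.

```python
from mmpretrain.structures import DataSample

from opencompass.multimodal.models.instructblip import (
    InstructBlipMMBenchPromptConstructor)

constructor = InstructBlipMMBenchPromptConstructor(reply_prompt='Answer:')

# Hypothetical sample; in the real pipeline these fields come from PackInputs.
data_sample = DataSample()
data_sample.set_metainfo(
    dict(question='What animal is in the image?',
         options='A. cat B. dog C. horse'))

inputs = constructor(dict(data_samples=[data_sample]))
print(inputs['prompt'])
# -> ' What animal is in the image? A. cat B. dog C. horse Answer:'
```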