Unverified Commit a4d68407 authored by Li Bo, committed by GitHub

[Feat] Add Otter to OpenCompass MMBench Evaluation (#232)

* add otter model for opencompass mmbench

* add docs

* add readme docs

* debug for otter opencompass eval

* delete unused folders

* change to default data path

* remove unused files

* remove unused files

* update

* update config file

* flake8 lint formatted and add prompt generator

* add prompt generator to config

* add a specific postprocess

* add post processor

* add post processor

* add post processor

* update according to suggestions

* remove unused redefinition
parent 7ca6ba62
# OTTER: Multi-modal In-context Instruction Tuning.
### Prepare the environment
```sh
cd opencompass/multimodal/models/otter
git clone https://github.com/Luodian/Otter.git
```
Then create a new conda environment and prepare it according to this [doc](https://github.com/Luodian/Otter).
### Start evaluation
#### Slurm
```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
```
#### PyTorch
```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval
```
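
`run.py` loads `configs/multimodal/tasks.py`, which wires in per-model configs such as the one added in this PR. A hedged sketch of what that wiring could look like for Otter (the module path and the `read_base` pattern follow OpenCompass conventions; the exact contents of `tasks.py` may differ):

```python
from mmengine.config import read_base

with read_base():
    from .otter.otter_9b_mmbench import (otter_9b_mmbench_dataloader,
                                         otter_9b_mmbench_evaluator,
                                         otter_9b_mmbench_model)

models = [otter_9b_mmbench_model]
datasets = [otter_9b_mmbench_dataloader]
evaluators = [otter_9b_mmbench_evaluator]
```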
# dataloader settings
from opencompass.multimodal.models.otter import (
    OTTERMMBenchPromptConstructor, OTTERMMBenchPostProcessor)

val_pipeline = [
    dict(type="mmpretrain.torchvision/Resize",
         size=(224, 224),
         interpolation=3),
    dict(type="mmpretrain.torchvision/ToTensor"),
    dict(
        type="mmpretrain.torchvision/Normalize",
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
    dict(
        type="mmpretrain.PackInputs",
        algorithm_keys=[
            "question", "answer", "options", "category", "l2-category",
            "context", "index", "options_dict",
        ],
    ),
]

dataset = dict(
    type="opencompass.MMBenchDataset",
    data_file="/path/to/mmbench/mmbench_test_20230712.tsv",
    pipeline=val_pipeline,
)

otter_9b_mmbench_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    collate_fn=dict(type="pseudo_collate"),
    sampler=dict(type="DefaultSampler", shuffle=False),
)

# model settings
otter_9b_mmbench_model = dict(
    type="otter-9b",
    model_path="/path/to/OTTER-Image-MPT7B/",  # noqa
    load_bit="bf16",
    prompt_constructor=dict(type=OTTERMMBenchPromptConstructor,
                            model_label='GPT',
                            user_label='User'),
    post_processor=dict(type=OTTERMMBenchPostProcessor),
)

# evaluation settings
otter_9b_mmbench_evaluator = [
    dict(type="opencompass.DumpResults",
         save_path="work_dirs/otter-9b-mmbench.xlsx")
]
@@ -10,4 +10,5 @@ if osp.exists('opencompass/multimodal/models/minigpt_4/MiniGPT-4'):
from .llava import * # noqa: F401, F403
from .openflamingo import * # noqa: F401, F403
from .otter import * # noqa: F401, F403
from .visualglm import * # noqa: F401, F403
from typing import TYPE_CHECKING

from transformers.utils import (OptionalDependencyNotAvailable,
                                is_torch_available)

if TYPE_CHECKING:
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass

import mmengine
import torch
import torch.nn as nn

from opencompass.registry import MM_MODELS

from .Otter.models.otter.modeling_otter import OtterForConditionalGeneration


@MM_MODELS.register_module('otter-9b')
class Otter(nn.Module):
    """Inference code of OTTER.

    Model details:
        OTTER: a multi-modal model based on OpenFlamingo
        (open-sourced version of DeepMind's Flamingo)
        https://github.com/Luodian/Otter

    Args:
        model_path (str): The path of the OTTER model
            in Huggingface model hub format.
        load_bit (str): The precision of the OTTER model, either "fp32"
            or "bf16".
        prompt_constructor (dict): Config dict of the prompt constructor.
        post_processor (dict): Config dict of the post processor.
        mode (str): Either 'generation' or 'loss'; only 'generation' is
            used for MMBench evaluation. Defaults to 'generation'.
    """

    def __init__(self,
                 model_path,
                 load_bit,
                 prompt_constructor,
                 post_processor,
                 mode='generation') -> None:
        super().__init__()
        torch_dtype = torch.bfloat16 if load_bit == 'bf16' else torch.float32
        self.model = OtterForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch_dtype)
        self.tokenizer = self.model.text_tokenizer
        self.tokenizer.padding_side = 'left'
        self.model_dtype = next(self.model.parameters()).dtype
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        # `forward` dispatches on this flag.
        self.mode = mode

    def forward(self, batch):
        if self.mode == 'generation':
            return self.generate(batch)
        elif self.mode == 'loss':
            return self.loss(batch)
        else:
            raise RuntimeError(f'Invalid mode "{self.mode}".')

    def generate(self, batch):
        inputs = self.prompt_constructor(batch)
        image = inputs['image']
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']
        # reshape to (B, T_img, F, C, H, W) = (1, 1, 1, C, H, W), the layout
        # expected by the OpenFlamingo-style vision encoder
        vision_x = image.unsqueeze(1).unsqueeze(0).to(dtype=self.model_dtype)
        lang_x = self.model.text_tokenizer([prompt], return_tensors='pt')
        # forbid the model from emitting new dialogue turns
        bad_words_id = self.model.text_tokenizer(['User:', 'GPT:']).input_ids
        generated_text = self.model.generate(
            vision_x=vision_x.to(self.model.device),
            lang_x=lang_x['input_ids'].to(self.model.device),
            attention_mask=lang_x['attention_mask'].to(self.model.device),
            do_sample=False,
            max_new_tokens=512,
            num_beams=3,
            bad_words_ids=bad_words_id,
            no_repeat_ngram_size=3,
        )
        for i, data_sample in enumerate(data_samples):
            output_text = self.post_processor(generated_text[i],
                                              self.model.text_tokenizer)
            data_sample.pred_answer = output_text
            data_samples[i] = data_sample
        return data_samples
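

# Illustrative usage sketch (assumption: not part of the original file).
# The registered class above is normally instantiated by the OpenCompass
# runner from the `otter_9b_mmbench_model` config dict shown earlier in
# this PR, roughly equivalent to:
#
#   from opencompass.registry import MM_MODELS
#   model = MM_MODELS.build(dict(type='otter-9b',
#                                model_path='/path/to/OTTER-Image-MPT7B/',
#                                load_bit='bf16',
#                                prompt_constructor=...,
#                                post_processor=...))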
import random
import re

import torch


class OTTERMMBenchPostProcessor:
    """Post processor for OTTER on MMBench."""

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
        # drop special leading token ids (0 / 1) before decoding
        if output_token[0] == 0:
            output_token = output_token[1:]
        if output_token[0] == 1:
            output_token = output_token[1:]
        output_text = tokenizer.decode(output_token,
                                       add_special_tokens=False)  # noqa
        output_text = self._extract_key_words(output_text)
        return output_text

    def _extract_key_words(self, output_text: str) -> str:
        # keep the text between '<answer>' and '<|endofchunk|>', then return
        # the first option letter (e.g. 'B' from 'B.') if one is present
        output_text = (output_text.split('<answer>')[-1].lstrip().rstrip().
                       split('<|endofchunk|>')[0].lstrip().rstrip())
        pattern = re.compile(r'([A-Z]\.)')
        res = pattern.findall(output_text)
        if len(res) > 0:
            output_text = res[0][:-1]
        return output_text


class OTTERCOCOCaptionPostProcessor:
    """Post processor for OTTER on COCO Caption."""

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
        if output_token[0] == 0:
            output_token = output_token[1:]
        if output_token[0] == 1:
            output_token = output_token[1:]
        output_text = tokenizer.decode(output_token,
                                       add_special_tokens=False)  # noqa
        output_text = (output_text.split('<answer>')[-1].lstrip().rstrip().
                       split('<|endofchunk|>')[0].lstrip().rstrip())
        pattern = re.compile(r'([A-Z]\.)')
        res = pattern.findall(output_text)
        if len(res) > 0:
            output_text = res[0][:-1]
        return output_text


class OTTERScienceQAPostProcessor:
    """Post processor for OTTER on ScienceQA."""

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
        if output_token[0] == 0:
            output_token = output_token[1:]
        if output_token[0] == 1:
            output_token = output_token[1:]
        output_text = tokenizer.decode(output_token,
                                       add_special_tokens=False)  # noqa
        output_text = (output_text.split('<answer>')[-1].lstrip().rstrip().
                       split('<|endofchunk|>')[0].lstrip().rstrip())
        pattern = re.compile(r'\(([A-Z])\)')
        output_text = pattern.findall(output_text)
        if len(output_text) == 0:
            output_text = random.choice(['A', 'B', 'C', 'D'])
        else:
            output_text = output_text[0]
        return output_text


class OTTERVQAPostProcessor:
    """Post processor for OTTER on VQA."""

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
        if output_token[0] == 0:
            output_token = output_token[1:]
        if output_token[0] == 1:
            output_token = output_token[1:]
        output_text = tokenizer.decode(output_token,
                                       add_special_tokens=False)  # noqa
        output_text = (output_text.split('<answer>')[-1].lstrip().rstrip().
                       split('<|endofchunk|>')[0].lstrip().rstrip())
        return output_text


class OTTERVSRPostProcessor:
    """Post processor for OTTER on VSR."""

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
        if output_token[0] == 0:
            output_token = output_token[1:]
        if output_token[0] == 1:
            output_token = output_token[1:]
        output_text = tokenizer.decode(output_token, add_special_tokens=False)
        pattern = r'yes|no|Yes|No'
        output_text = re.findall(pattern, output_text)
        if len(output_text) > 0:
            output_text = output_text[0].lower()
        return output_text


class OTTERMMEPostProcessor(OTTERMMBenchPostProcessor):
    """Post processor for OTTER on MME."""

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
        response = super().__call__(output_token, tokenizer)
        # extract yes or no, copied from the MME official evaluation script
        prefix_pred_ans = response[:4].lower()
        if 'yes' in prefix_pred_ans:
            pred_label = 'yes'
        elif 'no' in prefix_pred_ans:
            pred_label = 'no'
        else:
            pred_label = 'other'
        return pred_label
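

# Worked example for the MMBench extraction above (illustrative, not part of
# the file): for a decoded string such as
#   "... <answer> B. The sky is blue.<|endofchunk|>"
# the post processor keeps the text after '<answer>', drops everything from
# '<|endofchunk|>' onwards, and returns the first "capital letter followed by
# a dot" match without the dot, i.e. 'B'. If no such match exists, the
# trimmed free-form text is returned unchanged.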
from typing import List

import torch
from mmpretrain.structures import DataSample


class OTTERMMBenchPromptConstructor:
    """Prompt constructor for OTTER on MMBench.

    Args:
        user_label (str): Label that prefixes the user turn. Defaults to `''`.
        model_label (str): Label that prefixes the model turn.
            Defaults to `''`.
    """

    def __init__(self, user_label: str = '', model_label: str = '') -> None:
        self.image_token = '<image>'
        self.reply_token = '<answer>'
        self.user_label = user_label
        self.model_label = model_label

    def __call__(self, inputs: dict) -> dict:
        """Construct prompt.

        Args:
            inputs (dict): Input data containing image and data_samples.

        Returns:
            dict: A dict containing prompt, images and data_samples.
        """
        images = [image.unsqueeze(0) for image in inputs['inputs']]
        data_samples = [data_sample for data_sample in inputs['data_samples']]
        images = torch.cat(images, dim=0)
        inputs = {'image': images, 'data_samples': data_samples}
        data_samples = inputs['data_samples']
        prompt = self._process(data_samples)
        inputs.update({'prompt': prompt})
        return inputs

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        data_sample = data_samples[0]
        question = data_sample.get('question')
        options = data_sample.get('options')
        context = data_sample.get('context')
        # e.g. <image>User: What is the color of the sky? A: Blue B: Red C: Green D: Yellow GPT:<answer>  # noqa
        if context is not None:
            prompt = f'{self.image_token}{self.user_label} {context} {question} {options} {self.model_label}:{self.reply_token}'  # noqa
        else:
            prompt = f'{self.image_token}{self.user_label} {question} {options} {self.model_label}:{self.reply_token}'  # noqa
        return prompt


class OTTERCOCOCaptionPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on COCO Caption."""

    def _process(self, data_samples: List[DataSample]) -> str:
        # e.g. <image>User: a photo of GPT:<answer>  # noqa
        prompt = f'{self.image_token}{self.user_label} a photo of {self.model_label}:{self.reply_token}'  # noqa
        return prompt


class OTTERScienceQAPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on ScienceQA."""

    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}

    def _process(self, data_samples: List[DataSample]) -> str:
        assert len(data_samples) == 1, 'Only support batch size 1.'
        questions = [
            'Question: ' + data_sample.get('question') + '\n'
            for data_sample in data_samples
        ]  # noqa
        choices = [data_sample.get('choices') for data_sample in data_samples]
        choices = [[
            f'({self.choice_mapping[i]}) ' + item
            for i, item in enumerate(choice)
        ] for choice in choices]
        choices = [
            'Choices: ' + ' '.join(choice) + '\n' for choice in choices
        ]  # noqa
        contexts = [
            'Context: ' + data_sample.get('hint') + '\n'
            for data_sample in data_samples
        ]  # noqa
        question = questions[0]
        choice = choices[0]
        context = contexts[0]
        prompt = f'{self.image_token}{self.user_label} {context} {question} {choice} The answer is {self.model_label}:{self.reply_token}'  # noqa
        return prompt


class OTTERVQAPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on VQA."""

    def _process(self, data_samples: List[DataSample]) -> str:
        assert len(data_samples) == 1, 'Only support batch size 1.'
        questions = [
            data_sample.get('question') for data_sample in data_samples
        ]
        question = questions[0]
        prompt = f'{self.image_token}{self.user_label} {question}. Answer it with few words. {self.model_label}:{self.reply_token}'  # noqa
        return prompt


class OTTERVSRPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on VSR."""

    def _process(self, data_samples: List[DataSample]) -> str:
        assert len(data_samples) == 1, 'Only support batch size 1.'
        questions = [
            data_sample.get('question') for data_sample in data_samples
        ]
        question = questions[0]
        prompt = f'{self.image_token}{self.user_label} {question}. Is the above description correct? Answer yes or no. {self.model_label}:{self.reply_token}'  # noqa
        return prompt


class OTTERSEEDBenchPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on SEED-Bench."""

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        questions = [
            data_sample.get('question') for data_sample in data_samples
        ]
        question = questions[0]
        prompt = f'{self.image_token}{self.user_label} {question} {self.model_label}:{self.reply_token}'  # noqa
        return prompt


class OTTERMMEPromptConstructor(OTTERMMBenchPromptConstructor):
    """Prompt constructor for OTTER on MME.

    Args:
        user_label (str): Label that prefixes the user turn. Defaults to `''`.
        model_label (str): Label that prefixes the model turn.
            Defaults to `''`.
    """

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        question = data_samples[0].get('question')
        prompt = f'{self.image_token}{self.user_label} {question} {self.model_label}:{self.reply_token}'  # noqa
        return prompt
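

# Illustrative example (not part of the file): with the config values
# user_label='User' and model_label='GPT', an MMBench sample whose question is
# "What is the color of the sky?" and whose options string is
# "A. Blue B. Red C. Green D. Yellow" is rendered by
# OTTERMMBenchPromptConstructor as
#   "<image>User What is the color of the sky? A. Blue B. Red C. Green D. Yellow GPT:<answer>"
# which Otter.generate() tokenizes together with the preprocessed image tensor.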