Unverified commit a205629f authored by Yuan Liu, committed by GitHub

[Feature]: Refactor input and output (#176)

* [Feature]: Refactor input and output

* [Feature]: Update tasks
parent 876ade71
+from opencompass.multimodal.models.minigpt_4 import (
+    MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor)
+
 # dataloader settings
 val_pipeline = [
     dict(type='mmpretrain.torchvision/Resize',
@@ -9,8 +12,8 @@ val_pipeline = [
          std=(0.26862954, 0.26130258, 0.27577711)),
     dict(type='mmpretrain.PackInputs',
          algorithm_keys=[
-             'question', 'category', 'l2-category', 'context',
-             'index', 'options_dict', 'options', 'split'
+             'question', 'category', 'l2-category', 'context', 'index',
+             'options_dict', 'options', 'split'
          ])
 ]
@@ -27,11 +30,12 @@ minigpt_4_dataloader = dict(batch_size=1,
 # model settings
 minigpt_4_model = dict(
     type='minigpt-4-mmbench',
-    low_resource=True,
-    llama_model='/path/to/vicuna',
-    sys_prompt=  # noqa: E251
-    '###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. Shenzhen\n###Assistant: A\n'
-)
+    low_resource=False,
+    llama_model='/path/to/vicuna-7b/',
+    prompt_constructor=dict(type=MiniGPT4MMBenchPromptConstructor,
+                            image_prompt='###Human: <Img><ImageHere></Img>',
+                            reply_prompt='###Assistant:'),
+    post_processor=dict(type=MiniGPT4PostProcessor))

 # evaluation settings
 minigpt_4_evaluator = [
@@ -39,4 +43,4 @@ minigpt_4_evaluator = [
         save_path='work_dirs/minigpt-4-7b-mmbench.xlsx')
 ]
-minigpt_4_load_from = '/path/to/minigpt-4'  # noqa
+minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth'  # noqa
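With this refactor the hard-coded sys_prompt is gone: the model now builds its prompt constructor and post processor from these two config dicts via the MM_MODELS registry. A minimal sketch of what that resolution amounts to at init time, constructing the two helpers by hand with the registry plumbing omitted:

from opencompass.multimodal.models.minigpt_4 import (
    MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor)

# Mirrors what the model does with the two config dicts when it is built:
# each dict's remaining keys become constructor keyword arguments.
prompt_constructor = MiniGPT4MMBenchPromptConstructor(
    image_prompt='###Human: <Img><ImageHere></Img>',
    reply_prompt='###Assistant:')
post_processor = MiniGPT4PostProcessor()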
@@ -10,6 +10,6 @@ models = [minigpt_4_model]
 datasets = [minigpt_4_dataloader]
 evaluators = [minigpt_4_evaluator]
 load_froms = [minigpt_4_load_from]
-num_gpus = 1
-num_procs = 1
-launcher = 'slurm'
+num_gpus = 8
+num_procs = 8
+launcher = 'pytorch'
 from .minigpt_4 import MiniGPT4MMBench
+from .post_processor import MiniGPT4PostProcessor
+from .prompt_constructor import MiniGPT4MMBenchPromptConstructor

-__all__ = ['MiniGPT4MMBench']
+__all__ = [
+    'MiniGPT4MMBench', 'MiniGPT4PostProcessor',
+    'MiniGPT4MMBenchPromptConstructor'
+]
 import os
-import re
 import sys

+import mmengine
 import torch
 import torch.nn as nn
 from mmengine.device import get_device
@@ -43,15 +43,16 @@ class MiniGPT4MMBench(MiniGPT4):
     Args:
         llama_model (str): The path to the vicuna model.
-        sys_prompt (str): The prompt added to the beginning
-            of each query. Defaults to ''.
+        prompt_constructor (dict): The config of prompt constructor.
+        post_processor (dict): The config of post processor.
         low_resource (bool): Whether loaded in low precision.
             Defaults to False.
     """

     def __init__(self,
                  llama_model: str,
-                 sys_prompt: str = '',
+                 prompt_constructor: dict,
+                 post_processor: dict,
                  low_resource: bool = False) -> None:
         super().__init__(llama_model=llama_model, low_resource=low_resource)
@@ -62,7 +63,10 @@ class MiniGPT4MMBench(MiniGPT4):
         ]
         self.stopping_criteria = StoppingCriteriaList(
             [StoppingCriteriaSub(stops=stop_words_ids)])
-        self.sys_prompt = sys_prompt
+        self.prompt_constructor = mmengine.registry.build_from_cfg(
+            prompt_constructor, MM_MODELS)
+        self.post_processor = mmengine.registry.build_from_cfg(
+            post_processor, MM_MODELS)
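mmengine's build_from_cfg pops the dict's type key, which may be a registered name or (as in the config above) the class object itself, and calls it with the remaining keys as keyword arguments. A self-contained sketch with a toy registry; the registry and class names here are hypothetical stand-ins, not part of this commit:

from mmengine.registry import Registry, build_from_cfg

# Hypothetical registry standing in for MM_MODELS, for illustration only.
TOY_MODELS = Registry('toy_models')


@TOY_MODELS.register_module()
class EchoPostProcessor:

    def __init__(self, suffix: str = '') -> None:
        self.suffix = suffix


# 'type' may be a registered name (string) or the class itself; the
# remaining keys become constructor kwargs.
obj = build_from_cfg(dict(type='EchoPostProcessor', suffix='!'), TOY_MODELS)
assert isinstance(obj, EchoPostProcessor) and obj.suffix == '!'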
     def encode_img(self, image):
         device = image.device
@@ -96,38 +100,13 @@ class MiniGPT4MMBench(MiniGPT4):
     def generate(self, batch):
         inputs = self.pack_inputs(batch)
-        image = inputs.pop('image')
+        inputs = self.prompt_constructor(inputs)
+        image = inputs['image']
+        prompt = inputs['prompt']
         data_samples = inputs['data_samples']
-        samples = {'image': image}
-        question = [
-            data_sample.get('question') for data_sample in data_samples
-        ]
-        options = [data_sample.get('options') for data_sample in data_samples]
-        samples.update({'question': question[0]})
-        samples.update({'options': options[0]})
-        if data_samples[0].get('context') is not None:
-            context = [
-                data_sample.get('context') for data_sample in data_samples
-            ]
-            samples.update({'context': context})
-        data_sample = data_samples[0]
-        img_prompt = '###Human: <Img><ImageHere></Img> '
-        if 'context' in samples:
-            context_prompt = samples['context'][0]
-        question = samples['question']
-        options = samples['options']
-        if 'context' in samples:
-            prompt = img_prompt + ' ' + context_prompt + ' ' + question + ' ' + options  # noqa
-        else:
-            prompt = img_prompt + ' ' + question + ' ' + options
-        # prompt = self.sys_prompt + prompt
-        prompt = prompt + '###Assistant:'
-        image = samples['image']
-        img_embeds, _ = self.encode_img(image)
-        # The main process of generation
+        img_embeds, _ = self.encode_img(image)
         prompt_segs = prompt.split('<ImageHere>')
         prompt_seg_tokens = [
             self.llama_tokenizer(seg,
@@ -157,25 +136,10 @@ class MiniGPT4MMBench(MiniGPT4):
             stopping_criteria=self.stopping_criteria,
             num_return_sequences=1)

-        output_token = outputs[0]
-        if output_token[0] == 0:
-            output_token = output_token[1:]
-        if output_token[0] == 1:
-            output_token = output_token[1:]
-        output_text = self.llama_tokenizer.decode(output_token,
-                                                  add_special_tokens=False)
-        output_text = self.post_process(output_text)
-        data_sample.pred_answer = output_text
-        return data_sample
-
-    def post_process(self, output_text):
-        output_text = output_text.split('###')[0]
-        output_text = output_text.split('Assistant:')[-1].strip()
-        output_text = output_text.strip('</s><s>')
-        output_text = output_text.strip('</Img>')
-        output_text = output_text.strip()
-        pattern = re.compile(r'([A-Z]\.)')
-        res = pattern.findall(output_text)
-        if len(res) > 0:
-            output_text = res[0][:-1]
-        return output_text
+        for i, data_sample in enumerate(data_samples):
+            output_token = outputs[i]
+            output_text = self.post_processor(output_token,
+                                              self.llama_tokenizer)
+            data_sample.pred_answer = output_text
+            data_samples[i] = data_sample
+        return data_samples
+import re
+
+import torch
+
+
+class MiniGPT4PostProcessor:
+    """Post processor for MiniGPT-4 on MMBench."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+        if output_token[0] == 0:
+            output_token = output_token[1:]
+        if output_token[0] == 1:
+            output_token = output_token[1:]
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = self._extract_key_words(output_text)
+        return output_text
+
+    def _extract_key_words(self, output_text: str) -> str:
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip('</s><s>')
+        output_text = output_text.strip('</Img>')
+        output_text = output_text.strip()
+        pattern = re.compile(r'([A-Z]\.)')
+        res = pattern.findall(output_text)
+        if len(res) > 0:
+            output_text = res[0][:-1]
+        return output_text
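The post processor strips leading pad/BOS tokens, decodes, and then reduces the reply to an MMBench option letter. The extraction step can be exercised on its own; the reply string below is made up for illustration:

from opencompass.multimodal.models.minigpt_4 import MiniGPT4PostProcessor

pp = MiniGPT4PostProcessor()
# Keeps the text before '###', drops the 'Assistant:' prefix, then the
# regex collapses 'B. Shanghai' to the bare option letter 'B'.
text = 'Assistant: B. Shanghai###Human:'
print(pp._extract_key_words(text))  # -> 'B'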
+from typing import List
+
+from mmpretrain.structures import DataSample
+
+
+class MiniGPT4MMBenchPromptConstructor:
+    """Prompt constructor for MiniGPT-4 on MMBench.
+
+    Args:
+        image_prompt (str): Image prompt.
+        reply_prompt (str): Reply prompt.
+    """
+
+    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
+        self.image_prompt = image_prompt
+        self.reply_prompt = reply_prompt
+
+    def __call__(self, inputs: dict) -> dict:
+        """Construct prompt.
+
+        Args:
+            inputs (dict): Input data containing image and data_samples.
+
+        Returns:
+            dict: A dict containing prompt, images and data_samples.
+        """
+        data_samples = inputs['data_samples']
+        prompt = self._process(data_samples)
+        inputs.update({'prompt': prompt})
+        return inputs
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        """Process data sample to prompt.
+
+        Args:
+            data_samples (List[DataSample]): A list of data_samples.
+
+        Returns:
+            str: Prompt.
+        """
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        options = [data_sample.get('options') for data_sample in data_samples]
+        contexts = [data_sample.get('context') for data_sample in data_samples]
+        question = questions[0]
+        option = options[0]
+        context = contexts[0]
+        if context is not None:
+            prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
+        else:
+            prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
+        return prompt
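For a single sample, the constructor simply concatenates image prompt, optional context, question, options, and reply prompt. A sketch using a plain dict stand-in for DataSample, since only .get() is exercised here; the stand-in class and sample values are hypothetical:

from opencompass.multimodal.models.minigpt_4 import (
    MiniGPT4MMBenchPromptConstructor)


class FakeSample(dict):
    """Hypothetical stand-in for DataSample; only .get() is needed."""


constructor = MiniGPT4MMBenchPromptConstructor(
    image_prompt='###Human: <Img><ImageHere></Img>',
    reply_prompt='###Assistant:')
sample = FakeSample(question='What is shown?',
                    options='A. cat B. dog',
                    context=None)
inputs = constructor(dict(data_samples=[sample]))
print(inputs['prompt'])
# ###Human: <Img><ImageHere></Img> What is shown? A. cat B. dog ###Assistant: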
@@ -4,7 +4,7 @@ import os
 import os.path as osp
 import random
 import time
-from typing import Sequence
+from typing import List, Sequence

 import torch
 import torch.distributed as dist
@@ -78,6 +78,22 @@ class MultimodalInferTask:
         return osp.join(model_name,
                         f'{dataset_name}-{evaluator_name}.{file_extension}')

+    def get_output_paths(self, file_extension: str = 'json') -> List[str]:
+        """Get the paths to the output files.
+
+        Args:
+            file_extension (str): The file extension of the output files.
+                Default: 'json'.
+        """
+        model_name = self.model['type']
+        dataset_name = self.dataloader['dataset']['type']
+        evaluator_name = self.evaluator[0]['type']
+        return [
+            osp.join(model_name, dataset_name,
+                     f'{evaluator_name}.{file_extension}')
+        ]
+
     def get_command(self, cfg_path, template):
         """Get the command template for the task.
...
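Where the existing helper produces a flat model/dataset-evaluator.json name, the new get_output_paths nests results per dataset. A quick illustration; the dataset and evaluator type strings below are hypothetical:

import os.path as osp

model_name = 'minigpt-4-mmbench'  # self.model['type']
dataset_name = 'MMBenchDataset'   # hypothetical self.dataloader['dataset']['type']
evaluator_name = 'DumpResults'    # hypothetical self.evaluator[0]['type']

# Flat layout produced by the existing helper:
print(osp.join(model_name, f'{dataset_name}-{evaluator_name}.json'))
# minigpt-4-mmbench/MMBenchDataset-DumpResults.json

# Per-dataset layout produced by the new get_output_paths:
print(osp.join(model_name, dataset_name, f'{evaluator_name}.json'))
# minigpt-4-mmbench/MMBenchDataset/DumpResults.json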