Unverified Commit bd50bad8 authored by Yike Yuan's avatar Yike Yuan Committed by GitHub
Browse files

[Feat] Support mm models on public dataset and fix several issues. (#412)



* [Feat] Add public dataset support for visualglm, qwenvl, and flamingo

* [Fix] MMBench related changes.

* [Fix] Openflamingo inference.

* [Fix] Hide ckpt path.

* [Fix] Pre-commit.

---------
Co-authored-by: default avatarHaodong Duan <dhd.efz@gmail.com>
parent 7c2726c2
......@@ -53,9 +53,9 @@ class OTTERMMBenchPromptConstructor:
context = data_sample.get('context')
# e.g. <image>User: What is the color of the sky? A: Blue B: Red C: Green D: Yellow GPT:<answer> # noqa
if context is not None:
prompt = f'{self.image_token}{self.user_label} {context[i]} {question[i]} {options[i]} {self.model_label}:{self.reply_token}' # noqa
prompt = f'{self.image_token}{self.user_label} {context} {question} {options} {self.model_label}:{self.reply_token}' # noqa
else:
prompt = f'{self.image_token}{self.user_label} {question[i]} {options[i]} {self.model_label}:{self.reply_token}' # noqa
prompt = f'{self.image_token}{self.user_label} {question} {options} {self.model_label}:{self.reply_token}' # noqa
return prompt
......
from .post_processor import QwenVLBasePostProcessor
from .prompt_constructor import QwenVLMMBenchPromptConstructor
from .post_processor import QwenVLBasePostProcessor, QwenVLChatVSRPostProcessor
from .prompt_constructor import (QwenVLChatPromptConstructor,
QwenVLChatScienceQAPromptConstructor,
QwenVLChatVQAPromptConstructor,
QwenVLMMBenchPromptConstructor)
from .qwen import QwenVLBase, QwenVLChat
__all__ = [
'QwenVLBase', 'QwenVLChat', 'QwenVLBasePostProcessor',
'QwenVLMMBenchPromptConstructor'
'QwenVLMMBenchPromptConstructor', 'QwenVLChatPromptConstructor',
'QwenVLChatVQAPromptConstructor', 'QwenVLChatVSRPostProcessor',
'QwenVLChatScienceQAPromptConstructor'
]
......@@ -14,3 +14,18 @@ class QwenVLBasePostProcessor:
response = self.tokenizer.decode(pred)[input_len:]
response = response.replace('<|endoftext|>', '').strip()
return response
class QwenVLChatVSRPostProcessor:
    """VSR post processor for Qwen-VL-Chat.

    Collapses a free-form model response into one of the canonical VSR
    labels: ``'yes'``, ``'no'``, or ``'unknown'``.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, response: str) -> str:
        """Return the first label found in the response (case-insensitive).

        'yes' takes priority over 'no' when both substrings appear; any
        response containing neither maps to 'unknown'.
        """
        lowered = response.lower()
        for label in ('yes', 'no'):
            if label in lowered:
                return label
        return 'unknown'
......@@ -7,7 +7,7 @@ class QwenVLMMBenchPromptConstructor:
def __init__(self) -> None:
pass
def __call__(self, inputs: dict) -> str:
def __call__(self, inputs: dict) -> list:
data_samples = inputs['data_samples']
assert len(data_samples) == 1
data_sample = data_samples[0]
......@@ -27,3 +27,74 @@ class QwenVLMMBenchPromptConstructor:
},
]
return format_input
class QwenVLChatPromptConstructor:
    """Prompt constructor for Qwen-VL-Chat.

    Produces the chat-format input list expected by the model: an image
    placeholder entry followed by a fixed text prompt.
    """

    def __init__(self, prompt='') -> None:
        # Fixed text appended after the image slot for every sample.
        self.prompt = prompt

    def __call__(self, inputs: dict) -> list:
        """Build the two-element format input for a single data sample."""
        assert len(inputs['data_samples']) == 1
        # The image entry is only a placeholder; the real image tokens are
        # injected elsewhere in the pipeline.
        image_slot = {'image': 'This_is_path_to_an_image.'}
        text_slot = {'text': self.prompt}
        return [image_slot, text_slot]
class QwenVLChatVQAPromptConstructor:
    """VQA prompt constructor for Qwen-VL-Chat.

    Concatenates the sample's question with a fixed suffix prompt and
    wraps it in the chat-format input list (image placeholder + text).
    """

    def __init__(self, prompt='') -> None:
        # Suffix appended after each sample's question.
        self.prompt = prompt

    def __call__(self, inputs: dict) -> list:
        """Build the format input for exactly one VQA data sample."""
        samples = inputs['data_samples']
        assert len(samples) == 1
        question = samples[0].get('question')
        # The image entry is only a placeholder for the image tokens.
        return [
            {'image': 'This_is_path_to_an_image.'},
            {'text': question + self.prompt},
        ]
class QwenVLChatScienceQAPromptConstructor:
"""ScienceQA prompt constructor for Qwen-VL-Chat."""
choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
def __init__(self, prompt='') -> None:
self.prompt = prompt
def __call__(self, inputs: dict) -> list:
data_samples = inputs['data_samples']
assert len(data_samples) == 1
data_sample = data_samples[0]
question = data_sample.get('question')
choices = data_sample.get('choices')
choices = [
f'({self.choice_mapping[i]}) ' + item
for i, item in enumerate(choices)
]
choices = 'Choices: ' + ' '.join(choices) + '\n'
contexts = 'Context: ' + data_sample.get('hint')
format_input = [
{
'image': 'This_is_path_to_an_image.'
}, # Just placeholder for Image Tokens
{
'text': contexts + question + choices + self.prompt
},
]
return format_input
......@@ -55,6 +55,8 @@ class QwenVLBase(nn.Module):
if post_processor is not None:
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
else:
self.post_processor = None
self.is_caption_task = is_caption_task
self.model.transformer.forward = types.MethodType(
forward_hack, self.model.transformer)
......@@ -154,6 +156,9 @@ class QwenVLChat(QwenVLBase):
verbose=False,
errors='replace')
if self.post_processor:
response = self.post_processor(response)
data_sample = batch['data_samples'][0]
if self.is_caption_task:
data_sample.pred_caption = response
......
......@@ -81,9 +81,7 @@ class VisualGLMBasePromptConstructor:
data_samples = batch.pop('data_samples')
# generate text prompt
img_prompt = '<img></img>'
prompt = img_prompt + self.prompt
image_position = prompt.rfind('<img>') + 5
prompt = ['<img></img>' + self.prompt for i in range(images.shape[0])]
image_position = 5
......
......@@ -43,7 +43,14 @@ class VisualGLM(nn.Module):
if gen_kwargs:
self.gen_kwargs = gen_kwargs
else:
self.gen_kwargs = dict()
self.gen_kwargs = dict(
max_new_tokens=30,
num_beams=1,
do_sample=False,
repetition_penalty=1.0,
length_penalty=-1.0,
)
self.is_caption_task = is_caption_task
def encode_by_tokenizer(self, multi_prompts, image_position):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment