Unverified Commit bd50bad8 authored by Yike Yuan's avatar Yike Yuan Committed by GitHub
Browse files

[Feat] Support mm models on public dataset and fix several issues. (#412)



* [Feat] Add public dataset support for visualglm, qwenvl, and flamingo

* [Fix] MMBench related changes.

* [Fix] Openflamingo inference.

* [Fix] Hide ckpt path.

* [Fix] Pre-commit.

---------
Co-authored-by: default avatarHaodong Duan <dhd.efz@gmail.com>
parent 7c2726c2
...@@ -53,9 +53,9 @@ class OTTERMMBenchPromptConstructor: ...@@ -53,9 +53,9 @@ class OTTERMMBenchPromptConstructor:
context = data_sample.get('context') context = data_sample.get('context')
# e.g. <image>User: What is the color of the sky? A: Blue B: Red C: Green D: Yellow GPT:<answer> # noqa # e.g. <image>User: What is the color of the sky? A: Blue B: Red C: Green D: Yellow GPT:<answer> # noqa
if context is not None: if context is not None:
prompt = f'{self.image_token}{self.user_label} {context[i]} {question[i]} {options[i]} {self.model_label}:{self.reply_token}' # noqa prompt = f'{self.image_token}{self.user_label} {context} {question} {options} {self.model_label}:{self.reply_token}' # noqa
else: else:
prompt = f'{self.image_token}{self.user_label} {question[i]} {options[i]} {self.model_label}:{self.reply_token}' # noqa prompt = f'{self.image_token}{self.user_label} {question} {options} {self.model_label}:{self.reply_token}' # noqa
return prompt return prompt
......
from .post_processor import QwenVLBasePostProcessor from .post_processor import QwenVLBasePostProcessor, QwenVLChatVSRPostProcessor
from .prompt_constructor import QwenVLMMBenchPromptConstructor from .prompt_constructor import (QwenVLChatPromptConstructor,
QwenVLChatScienceQAPromptConstructor,
QwenVLChatVQAPromptConstructor,
QwenVLMMBenchPromptConstructor)
from .qwen import QwenVLBase, QwenVLChat from .qwen import QwenVLBase, QwenVLChat
__all__ = [ __all__ = [
'QwenVLBase', 'QwenVLChat', 'QwenVLBasePostProcessor', 'QwenVLBase', 'QwenVLChat', 'QwenVLBasePostProcessor',
'QwenVLMMBenchPromptConstructor' 'QwenVLMMBenchPromptConstructor', 'QwenVLChatPromptConstructor',
'QwenVLChatVQAPromptConstructor', 'QwenVLChatVSRPostProcessor',
'QwenVLChatScienceQAPromptConstructor'
] ]
...@@ -14,3 +14,18 @@ class QwenVLBasePostProcessor: ...@@ -14,3 +14,18 @@ class QwenVLBasePostProcessor:
response = self.tokenizer.decode(pred)[input_len:] response = self.tokenizer.decode(pred)[input_len:]
response = response.replace('<|endoftext|>', '').strip() response = response.replace('<|endoftext|>', '').strip()
return response return response
class QwenVLChatVSRPostProcessor:
"""VSR post processor for Qwen-VL-Chat."""
def __init__(self) -> None:
pass
def __call__(self, response: str) -> str:
if 'yes' in response.lower():
return 'yes'
elif 'no' in response.lower():
return 'no'
else:
return 'unknown'
...@@ -7,7 +7,7 @@ class QwenVLMMBenchPromptConstructor: ...@@ -7,7 +7,7 @@ class QwenVLMMBenchPromptConstructor:
def __init__(self) -> None: def __init__(self) -> None:
pass pass
def __call__(self, inputs: dict) -> str: def __call__(self, inputs: dict) -> list:
data_samples = inputs['data_samples'] data_samples = inputs['data_samples']
assert len(data_samples) == 1 assert len(data_samples) == 1
data_sample = data_samples[0] data_sample = data_samples[0]
...@@ -27,3 +27,74 @@ class QwenVLMMBenchPromptConstructor: ...@@ -27,3 +27,74 @@ class QwenVLMMBenchPromptConstructor:
}, },
] ]
return format_input return format_input
class QwenVLChatPromptConstructor:
    """Prompt constructor for Qwen-VL-Chat.

    Wraps a fixed text prompt into the list-of-dicts chat format consumed
    by Qwen-VL-Chat; the 'image' entry is only a placeholder that is
    replaced with the real image path downstream.
    """

    def __init__(self, prompt: str = '') -> None:
        # Fixed text appended as the sample's text segment.
        self.prompt = prompt

    def __call__(self, inputs: dict) -> list:
        """Build the chat-format input for a single sample.

        Args:
            inputs (dict): Must contain 'data_samples' with exactly one
                sample (batch size 1 is assumed throughout).

        Returns:
            list: ``[{'image': <placeholder>}, {'text': self.prompt}]``.
        """
        assert len(inputs['data_samples']) == 1
        format_input = [
            {
                'image': 'This_is_path_to_an_image.'
            },  # Just placeholder for Image Tokens
            {
                'text': self.prompt
            },
        ]
        return format_input
class QwenVLChatVQAPromptConstructor:
    """VQA prompt constructor for Qwen-VL-Chat.

    Concatenates the sample's question with a fixed suffix prompt and
    emits the list-of-dicts chat format; the 'image' entry is a
    placeholder replaced with the real image path downstream.
    """

    def __init__(self, prompt='') -> None:
        # Suffix appended directly after the question text.
        self.prompt = prompt

    def __call__(self, inputs: dict) -> list:
        """Build the chat-format input for one VQA sample."""
        samples = inputs['data_samples']
        assert len(samples) == 1
        question = samples[0].get('question')
        return [
            {'image': 'This_is_path_to_an_image.'},  # placeholder for image tokens
            {'text': question + self.prompt},
        ]
class QwenVLChatScienceQAPromptConstructor:
    """ScienceQA prompt constructor for Qwen-VL-Chat.

    Renders hint/context, question, lettered answer choices and a fixed
    suffix prompt into Qwen-VL-Chat's list-of-dicts chat format; the
    'image' entry is a placeholder replaced downstream.
    """

    # Choice index -> option letter used when rendering the choice list.
    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}

    def __init__(self, prompt='') -> None:
        # Suffix appended after the rendered choices.
        self.prompt = prompt

    def __call__(self, inputs: dict) -> list:
        """Build the chat-format input for one ScienceQA sample."""
        samples = inputs['data_samples']
        assert len(samples) == 1
        sample = samples[0]
        labelled = ' '.join(
            f'({self.choice_mapping[idx]}) {option}'
            for idx, option in enumerate(sample.get('choices')))
        # NOTE: no separator between hint and question — kept as-is to
        # preserve the exact prompt string the model was evaluated with.
        text = ('Context: ' + sample.get('hint') + sample.get('question') +
                'Choices: ' + labelled + '\n' + self.prompt)
        return [
            {'image': 'This_is_path_to_an_image.'},  # placeholder for image tokens
            {'text': text},
        ]
...@@ -55,6 +55,8 @@ class QwenVLBase(nn.Module): ...@@ -55,6 +55,8 @@ class QwenVLBase(nn.Module):
if post_processor is not None: if post_processor is not None:
self.post_processor = mmengine.registry.build_from_cfg( self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS) post_processor, MM_MODELS)
else:
self.post_processor = None
self.is_caption_task = is_caption_task self.is_caption_task = is_caption_task
self.model.transformer.forward = types.MethodType( self.model.transformer.forward = types.MethodType(
forward_hack, self.model.transformer) forward_hack, self.model.transformer)
...@@ -154,6 +156,9 @@ class QwenVLChat(QwenVLBase): ...@@ -154,6 +156,9 @@ class QwenVLChat(QwenVLBase):
verbose=False, verbose=False,
errors='replace') errors='replace')
if self.post_processor:
response = self.post_processor(response)
data_sample = batch['data_samples'][0] data_sample = batch['data_samples'][0]
if self.is_caption_task: if self.is_caption_task:
data_sample.pred_caption = response data_sample.pred_caption = response
......
...@@ -81,9 +81,7 @@ class VisualGLMBasePromptConstructor: ...@@ -81,9 +81,7 @@ class VisualGLMBasePromptConstructor:
data_samples = batch.pop('data_samples') data_samples = batch.pop('data_samples')
# generate text prompt # generate text prompt
img_prompt = '<img></img>' prompt = ['<img></img>' + self.prompt for i in range(images.shape[0])]
prompt = img_prompt + self.prompt
image_position = prompt.rfind('<img>') + 5
image_position = 5 image_position = 5
......
...@@ -43,7 +43,14 @@ class VisualGLM(nn.Module): ...@@ -43,7 +43,14 @@ class VisualGLM(nn.Module):
if gen_kwargs: if gen_kwargs:
self.gen_kwargs = gen_kwargs self.gen_kwargs = gen_kwargs
else: else:
self.gen_kwargs = dict() self.gen_kwargs = dict(
max_new_tokens=30,
num_beams=1,
do_sample=False,
repetition_penalty=1.0,
length_penalty=-1.0,
)
self.is_caption_task = is_caption_task self.is_caption_task = is_caption_task
def encode_by_tokenizer(self, multi_prompts, image_position): def encode_by_tokenizer(self, multi_prompts, image_position):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment