Unverified commit bd50bad8 authored by Yike Yuan, committed by GitHub

[Feat] Support MM models on public datasets and fix several issues. (#412)



* [Feat] Add public dataset support for visualglm, qwenvl, and flamingo

* [Fix] MMBench-related changes.

* [Fix] Openflamingo inference.

* [Fix] Hide ckpt path.

* [Fix] Pre-commit.

---------
Co-authored-by: Haodong Duan <dhd.efz@gmail.com>
parent 7c2726c2
from opencompass.multimodal.models.qwen import QwenVLChatScienceQAPromptConstructor

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(448, 448),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'gt_answer', 'choices', 'hint', 'lecture', 'solution'
         ])
]

dataset = dict(type='mmpretrain.ScienceQA',
               data_root='./data/scienceqa',
               split='val',
               split_file='pid_splits.json',
               ann_file='problems.json',
               image_only=True,
               data_prefix=dict(img_path='val'),
               pipeline=val_pipeline)

qwen_scienceqa_dataloader = dict(batch_size=1,
                                 num_workers=4,
                                 dataset=dataset,
                                 collate_fn=dict(type='pseudo_collate'),
                                 sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
qwen_scienceqa_model = dict(
    type='qwen-vl-chat',
    pretrained_path='Qwen/Qwen-VL-Chat',  # or Huggingface repo id
    prompt_constructor=dict(type=QwenVLChatScienceQAPromptConstructor)
)

# evaluation settings
qwen_scienceqa_evaluator = [dict(type='mmpretrain.ScienceQAMetric')]
from opencompass.multimodal.models.qwen import QwenVLChatVQAPromptConstructor

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(448, 448),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(
    type='mmpretrain.TextVQA',
    data_root='data/textvqa',
    ann_file='annotations/TextVQA_0.5.1_val.json',
    pipeline=val_pipeline,
    data_prefix='images/train_images',
)

qwen_textvqa_dataloader = dict(batch_size=1,
                               num_workers=4,
                               dataset=dataset,
                               collate_fn=dict(type='pseudo_collate'),
                               sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
qwen_textvqa_model = dict(
    type='qwen-vl-chat',
    pretrained_path='Qwen/Qwen-VL-Chat',  # or Huggingface repo id
    prompt_constructor=dict(type=QwenVLChatVQAPromptConstructor)
)

# evaluation settings
qwen_textvqa_evaluator = [dict(type='mmpretrain.VQAAcc')]
from opencompass.multimodal.models.qwen import QwenVLChatVQAPromptConstructor

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(448, 448),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(type='mmpretrain.VizWiz',
               data_root='data/vizwiz/',
               data_prefix='Images/val',
               ann_file='Annotations/val.json',
               pipeline=val_pipeline)

qwen_vizwiz_dataloader = dict(batch_size=1,
                              num_workers=4,
                              dataset=dataset,
                              collate_fn=dict(type='pseudo_collate'),
                              sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
qwen_vizwiz_model = dict(
    type='qwen-vl-chat',
    pretrained_path='Qwen/Qwen-VL-Chat',  # or Huggingface repo id
    prompt_constructor=dict(type=QwenVLChatVQAPromptConstructor)
)

# evaluation settings
qwen_vizwiz_evaluator = [dict(type='mmpretrain.VQAAcc')]
from opencompass.multimodal.models.qwen import QwenVLChatVQAPromptConstructor

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(448, 448),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(
    type='mmpretrain.COCOVQA',
    data_root='data/coco',
    data_prefix='images/val2014',
    question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
    ann_file='annotations/v2_mscoco_val2014_annotations.json',
    pipeline=val_pipeline)

qwen_vqav2_dataloader = dict(batch_size=1,
                             num_workers=4,
                             dataset=dataset,
                             collate_fn=dict(type='pseudo_collate'),
                             sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
qwen_vqav2_model = dict(
    type='qwen-vl-chat',
    pretrained_path='Qwen/Qwen-VL-Chat',  # or Huggingface repo id
    prompt_constructor=dict(type=QwenVLChatVQAPromptConstructor)
)

# evaluation settings
qwen_vqav2_evaluator = [dict(type='mmpretrain.VQAAcc')]
from opencompass.multimodal.models.qwen import QwenVLChatVQAPromptConstructor, QwenVLChatVSRPostProcessor

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(448, 448),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    )
]

dataset = dict(type='mmpretrain.VSR',
               data_root='data/vsr/',
               data_prefix='images/',
               ann_file='annotations/test.json',
               pipeline=val_pipeline)

qwen_vsr_dataloader = dict(batch_size=1,
                           num_workers=4,
                           dataset=dataset,
                           collate_fn=dict(type='pseudo_collate'),
                           sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
qwen_vsr_model = dict(
    type='qwen-vl-chat',
    pretrained_path='Qwen/Qwen-VL-Chat',  # or Huggingface repo id
    prompt_constructor=dict(type=QwenVLChatVQAPromptConstructor),
    post_processor=dict(type=QwenVLChatVSRPostProcessor)
)

# evaluation settings
qwen_vsr_evaluator = [dict(type='mmpretrain.GQAAcc')]
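The `*_dataloader` dicts above follow the standard MMEngine dataloader config schema, so they can be materialized outside a full OpenCompass run for a quick sanity check. The sketch below is illustrative only and not part of this commit; it assumes mmpretrain and its torchvision transform wrappers are installed and that the ScienceQA data is laid out under `./data/scienceqa` as configured above.

# Illustrative sketch: build the ScienceQA dataloader directly from the config dict.
from mmengine.runner import Runner

scienceqa_loader = Runner.build_dataloader(qwen_scienceqa_dataloader)
batch = next(iter(scienceqa_loader))
# With PackInputs + pseudo_collate, each batch is a dict of lists
# holding 'inputs' and 'data_samples'.
print(batch['data_samples'][0].question)  # e.g. the raw ScienceQA question text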
@@ -32,7 +32,7 @@ visualglm_coco_caption_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     is_caption_task=True,
-    prompt_constructor=dict(type=VisualGLMBasePromptConstructor),
+    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
     post_processor=dict(type=VisualGLMBasePostProcessor)
 )
...
@@ -33,7 +33,7 @@ visualglm_flickr30k_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     is_caption_task=True,
-    prompt_constructor=dict(type=VisualGLMBasePromptConstructor),
+    prompt_constructor=dict(type=VisualGLMBasePromptConstructor, system_prompt='A photo of'),
     post_processor=dict(type=VisualGLMBasePostProcessor)
 )
...
@@ -20,22 +20,23 @@ dataset = dict(type='opencompass.MMBenchDataset',
                data_file='data/mmbench/mmbench_test_20230712.tsv',
                pipeline=val_pipeline)

-mmbench_dataloader = dict(batch_size=1,
+visualglm_mmbench_dataloader = dict(batch_size=1,
                           num_workers=4,
                           dataset=dataset,
                           collate_fn=dict(type='pseudo_collate'),
                           sampler=dict(type='DefaultSampler', shuffle=False))

 # model settings
-visualglm_model = dict(
+visualglm_mmbench_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     prompt_constructor=dict(type=VisualGLMMMBenchPromptConstructor),
-    post_processor=dict(type=VisualGLMBasePostProcessor)
+    post_processor=dict(type=VisualGLMBasePostProcessor),
+    gen_kwargs=dict(max_new_tokens=50,num_beams=5,do_sample=False,repetition_penalty=1.0,length_penalty=-1.0)
 )

 # evaluation settings
-mmbench_evaluator = [
+visualglm_mmbench_evaluator = [
     dict(type='opencompass.DumpResults',
          save_path='work_dirs/visualglm-6b-mmbench.xlsx')
 ]
@@ -26,7 +26,7 @@ dataset = dict(type='mmpretrain.ScienceQA',
                data_prefix=dict(img_path='val'),
                pipeline=val_pipeline)

-visualglm_vizwiz_dataloader = dict(batch_size=1,
+visualglm_scienceqa_dataloader = dict(batch_size=1,
                                    num_workers=4,
                                    dataset=dataset,
                                    collate_fn=dict(type='pseudo_collate'),
...
@@ -33,7 +33,7 @@ visualglm_textvqa_dataloader = dict(batch_size=1,
                                     sampler=dict(type='DefaultSampler', shuffle=False))

 # model settings
-visualglm_model = dict(
+visualglm_textvqa_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     prompt_constructor=dict(type=VisualGLMVQAPromptConstructor),
...
@@ -31,7 +31,7 @@ visualglm_vizwiz_dataloader = dict(batch_size=1,
                                    sampler=dict(type='DefaultSampler', shuffle=False))

 # model settings
-visualglm_model = dict(
+visualglm_vizwiz_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     prompt_constructor=dict(type=VisualGLMVQAPromptConstructor),
...
@@ -33,7 +33,7 @@ visualglm_vqav2_dataloader = dict(batch_size=1,
                                   sampler=dict(type='DefaultSampler', shuffle=False))

 # model settings
-visualglm_model = dict(
+visualglm_vqav2_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     prompt_constructor=dict(type=VisualGLMVQAPromptConstructor),
...
@@ -32,7 +32,7 @@ visualglm_vsr_dataloader = dict(batch_size=1,
                                 sampler=dict(type='DefaultSampler', shuffle=False))

 # model settings
-visualglm_model = dict(
+visualglm_vsr_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     prompt_constructor=dict(type=VisualGLMVQAPromptConstructor),
...
@@ -19,9 +19,6 @@ if osp.exists('opencompass/multimodal/models/mplug_owl/mPLUG-Owl'):
     from .mplug_owl import *  # noqa: F401, F403

 from .openflamingo import *  # noqa: F401, F403
-
-if osp.exists('opencompass/multimodal/models/otter/Otter'):
-    from .otter import *  # noqa: F401, F403
-
+from .otter import *  # noqa: F401, F403
 from .qwen import *  # noqa: F401, F403
 from .visualglm import *  # noqa: F401, F403
 from .openflamingo import OpenFlamingoInferencer
+from .post_processor import OpenFlamingoVSRPostProcessor
+from .prompt_constructor import (OpenFlamingoCaptionPromptConstructor,
+                                 OpenFlamingoMMBenchPromptConstructor,
+                                 OpenFlamingoScienceQAPromptConstructor,
+                                 OpenFlamingoVQAPromptConstructor)

-__all__ = ['OpenFlamingoInferencer']
+__all__ = [
+    'OpenFlamingoInferencer', 'OpenFlamingoMMBenchPromptConstructor',
+    'OpenFlamingoCaptionPromptConstructor', 'OpenFlamingoVQAPromptConstructor',
+    'OpenFlamingoScienceQAPromptConstructor', 'OpenFlamingoVSRPostProcessor'
+]
+import re
 from typing import List, Optional, Union

 import mmengine
@@ -21,17 +22,18 @@ class OpenFlamingoInferencer(Flamingo):
     """

     def __init__(self,
-                 prompt_constructor: Optional[dict] = None,
+                 prompt_constructor: dict,
                  post_processor: Optional[dict] = None,
                  mode: str = 'generation',
                  **kwargs):
         super().__init__(**kwargs)
-        if prompt_constructor is not None:
-            self.prompt_constructor = mmengine.registry.build_from_cfg(
-                prompt_constructor, MM_MODELS)
+        self.prompt_constructor = mmengine.registry.build_from_cfg(
+            prompt_constructor, MM_MODELS)
         if post_processor is not None:
             self.post_processor = mmengine.registry.build_from_cfg(
                 post_processor, MM_MODELS)
+        else:
+            self.post_processor = None
         self.mode = mode

     def preprocess_text(self, data_samples: List[DataSample],
@@ -46,16 +48,7 @@ class OpenFlamingoInferencer(Flamingo):
         Returns:
             List[DataSample]: Return list of data samples.
         """
-        prompts = []
-        for sample in data_samples:
-            question = sample.get('question')
-            option = sample.get('options')
-            prompt = '<image>' + question + ' ' + option + ' ' + 'Answer:'
-            if data_samples[0].get('context') is not None:
-                prompt = sample.get('context') + ' ' + prompt
-            prompts.append(prompt)
+        prompts = self.prompt_constructor(data_samples)

         self.tokenizer.padding_side = 'left'
         input_text = self.tokenizer(
@@ -67,6 +60,42 @@ class OpenFlamingoInferencer(Flamingo):
         ).to(device)
         return input_text

+    def post_process(
+            self, outputs: torch.Tensor,
+            data_samples: Optional[List[DataSample]]) -> List[DataSample]:
+        """Perform post process for outputs for different task.
+
+        Args:
+            outputs (torch.Tensor): The generated outputs.
+            data_samples (List[DataSample], optional): The annotation
+                data of every samples.
+
+        Returns:
+            List[DataSample]: Return list of data samples.
+        """
+        outputs = self.tokenizer.batch_decode(outputs,
+                                              skip_special_tokens=True)
+
+        if data_samples is None:
+            data_samples = [DataSample() for _ in range(len(outputs))]
+
+        for output, data_sample in zip(outputs, data_samples):
+            # remove text pattern
+            if self.task == 'caption':
+                data_sample.pred_caption = re.split('Output', output,
+                                                    1)[0].replace('"', '')
+                if self.post_processor:
+                    data_sample.pred_caption = self.post_processor(
+                        data_sample.pred_caption)
+            elif self.task == 'vqa':
+                data_sample.pred_answer = re.split('Question|Answer', output,
+                                                   1)[0]
+                if self.post_processor:
+                    data_sample.pred_answer = self.post_processor(
+                        data_sample.pred_answer)
+
+        return data_samples
+
     def forward(self, batch: dict) -> Union[DataSample, List[DataSample]]:

         if self.mode == 'generation':
...
class OpenFlamingoVSRPostProcessor:
    """VSR post processor for Openflamingo."""

    def __init__(self) -> None:
        pass

    def __call__(self, raw_response: str) -> str:
        if 'yes' in raw_response.lower():
            return 'yes'
        elif 'no' in raw_response.lower():
            return 'no'
        else:
            return 'unknown'
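A quick illustrative usage of the post processor above; it simply collapses a free-form model response onto the labels expected by the VSR metric:

post_processor = OpenFlamingoVSRPostProcessor()
print(post_processor('Yes, the cat is on the table.'))  # -> 'yes'
print(post_processor('It is unclear.'))                 # -> 'unknown'
# Note: the check is substring-based, so a response containing 'not'
# would also match 'no'.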
from typing import Optional

from mmpretrain.structures import DataSample


class OpenFlamingoMMBenchPromptConstructor:
    """MMBench prompt constructor for OpenFlamingo."""

    def __init__(self) -> None:
        pass

    def __call__(self, data_samples: DataSample) -> tuple:
        """Construct prompt.

        Args:
            data_samples (DataSample): Input data_samples.

        Returns:
            Raw text input (str).
        """
        assert len(data_samples) == 1
        sample = data_samples[0]
        prompts = []
        question = sample.get('question')
        option = sample.get('options')
        prompt = '<image>' + question + ' ' + option + ' ' + 'Answer:'
        if sample.get('context') is not None:
            prompt = sample.get('context') + ' ' + prompt
        prompts.append(prompt)
        return prompts
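# Example (illustrative, hypothetical sample values): for a DataSample whose
# 'question' is 'What is shown?', 'options' is 'A. cat B. dog' and 'context'
# is 'Look at the image.', the constructor above returns
# ['Look at the image. <image>What is shown? A. cat B. dog Answer:'].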


class OpenFlamingoCaptionPromptConstructor:
    """Caption prompt constructor for OpenFlamingo."""

    def __init__(self, shot_prompt: Optional[str] = None) -> None:
        if shot_prompt:
            self.shot_prompt = shot_prompt
        else:
            self.shot_prompt = (
                'Output:A child holding a flowered umbrella and petting a yak.<|endofchunk|>'  # noqa
                'Output:The child is holding a brush close to his mouth.<|endofchunk|>'  # noqa
            )  # noqa

    def __call__(self, data_samples: DataSample) -> tuple:
        """Construct prompt.

        Args:
            data_samples (DataSample): Input data_samples.

        Returns:
            Raw text input (str).
        """
        assert len(data_samples) == 1
        prompts = []
        prompt = '<image>Output:'
        prompts.append(self.shot_prompt + prompt)
        return prompts


class OpenFlamingoVQAPromptConstructor:
    """VQA prompt constructor for OpenFlamingo."""

    def __init__(self, shot_prompt: Optional[str] = None) -> None:
        if shot_prompt:
            self.shot_prompt = shot_prompt
        else:
            self.shot_prompt = (
                'Question:Is the sky dark? Short Answer:yes<|endofchunk|>'  # noqa: E501
                'Question:What is on the white wall? Short Answer:pipe<|endofchunk|>'  # noqa: E501
            )  # noqa

    def __call__(self, data_samples: DataSample) -> tuple:
        """Construct prompt.

        Args:
            data_samples (DataSample): Input data_samples.

        Returns:
            Raw text input (str).
        """
        prompts = []
        for sample in data_samples:
            question = sample.get('question')
            prompt = '<image>Question:{} Short Answer:'.format(question)
            prompts.append(self.shot_prompt + prompt)
        return prompts
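# Example (illustrative): for a sample whose 'question' is 'What color is the
# bus?', __call__ returns the two-shot prompt above followed by
# '<image>Question:What color is the bus? Short Answer:'.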


class OpenFlamingoScienceQAPromptConstructor:
    """ScienceQA prompt constructor for OpenFlamingo."""
    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}

    def __init__(self, shot_prompt: Optional[str] = None) -> None:
        if shot_prompt:
            self.shot_prompt = shot_prompt
        else:
            self.shot_prompt = (
                "Context:Question:Which of these states is farthest north? Choices:['(A) West Virginia' '(B) Louisiana' '(C) Arizona' '(D) Oklahoma'] Answer with a single character: A<|endofchunk|>"  # noqa
                'Context:The diagrams below show two pure samples of gas in identical closed, rigid containers. Each colored ball represents one gas particle. Both samples have the same number of particles.'  # noqa
                "Question:Compare the average kinetic energies of the particles in each sample. Which sample has the higher temperature? Choices:'[(A) neither' '(B) sample A' '(C) sample B'] Answer with a single character: C<|endofchunk|>"  # noqa
            )  # noqa

    def __call__(self, data_samples: DataSample) -> tuple:
        """Construct prompt.

        Args:
            data_samples (DataSample): Input data_samples.

        Returns:
            Raw text input (str).
        """
        assert len(data_samples) == 1
        sample = data_samples[0]
        question = sample.get('question')
        choices = sample.get('choices')
        choices = [
            f'({self.choice_mapping[i]}) ' + item
            for i, item in enumerate(choices)
        ]
        hint = sample.get('hint')
        prompts = []
        prompt = '<image>Context:{} Question:{} Choices:{}'.format(
            hint, question, choices)
        prompt += ' Answer with a single character:'
        prompts.append(self.shot_prompt + prompt)
        return prompts
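As an illustration with hypothetical sample values, the ScienceQA constructor renders the choices through `choice_mapping` before appending the answer cue:

constructor = OpenFlamingoScienceQAPromptConstructor()
sample = DataSample()
sample.set_metainfo(dict(question='Which is a mammal?',
                         choices=['frog', 'whale'],
                         hint='Think about how each animal breathes.'))
print(constructor([sample])[0])
# -> shot prompt followed by:
# "<image>Context:Think about how each animal breathes. Question:Which is a
#  mammal? Choices:['(A) frog', '(B) whale'] Answer with a single character:"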
@@ -9,3 +9,11 @@ if TYPE_CHECKING:
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
     pass
+
+from .otter import Otter
+from .post_processor import OTTERMMBenchPostProcessor
+from .prompt_constructor import OTTERMMBenchPromptConstructor
+
+__all__ = [
+    'Otter', 'OTTERMMBenchPromptConstructor', 'OTTERMMBenchPostProcessor'
+]
+import importlib
+
 import mmengine
 import torch
 import torch.nn as nn
+from mmengine.device import get_device

 from opencompass.registry import MM_MODELS

-from .Otter.models.otter.modeling_otter import OtterForConditionalGeneration
-
 @MM_MODELS.register_module('otter-9b')
 class Otter(nn.Module):
@@ -19,14 +20,20 @@ class Otter(nn.Module):
         model_path (str): The path of OTTER model
             in Huggingface model hub format.
         load_bit (str): The bit of OTTER model, can be "fp32" or "bf16".
+        mode (str): The mode of inference. Defaults to 'generation'.
     """

-    def __init__(self, model_path, load_bit, prompt_constructor,
-                 post_processor) -> None:
+    def __init__(self,
+                 model_path,
+                 load_bit,
+                 prompt_constructor,
+                 post_processor,
+                 mode='generation') -> None:
         super().__init__()
         torch_dtype = torch.bfloat16 if load_bit == 'bf16' else torch.float32
-        self.model = OtterForConditionalGeneration.from_pretrained(
-            model_path, torch_dtype=torch_dtype)
+        otter_ai = importlib.import_module('otter_ai')
+        self.model = otter_ai.OtterForConditionalGeneration.from_pretrained(
+            model_path, torch_dtype=torch_dtype, device_map=get_device())
         self.tokenizer = self.model.text_tokenizer
         self.tokenizer.padding_side = 'left'
         self.model_dtype = next(self.model.parameters()).dtype
@@ -35,6 +42,7 @@ class Otter(nn.Module):
         if post_processor is not None:
             self.post_processor = mmengine.registry.build_from_cfg(
                 post_processor, MM_MODELS)
+        self.mode = mode

     def forward(self, batch):
         if self.mode == 'generation':
...