"vscode:/vscode.git/clone" did not exist on "d2f16726905f53952fc744147802076f6a1d8711"
Unverified commit dc6e54f6, authored by Yuan Liu, committed by GitHub

[Feature]: Verify the acc of these public datasets (#269)

* [Feature]: Refactor public dataset eval

* [Feature]: Verify public dataset acc
parent 3f37c40a
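
In brief: the commit standardizes the placeholder Vicuna weight path across the MiniGPT-4 configs, adds an `is_caption_task` flag that the caption configs set and `MiniGPT4Inferencer` consumes, points the Flickr30k evaluator at its own ground-truth annotations instead of the COCO ones, and disables `low_resource` mode in the VSR config. A minimal sketch of the new prediction routing, mirroring the inferencer hunk at the end of this diff (`data_sample` and `output_text` as in that code):

    # Sketch of the routing this commit adds (see the last hunk below).
    # Caption metrics such as mmpretrain.COCOCaption presumably read
    # `pred_caption`; the VQA-style metrics keep reading `pred_answer`.
    if self.is_caption_task:
        data_sample.pred_caption = output_text
    else:
        data_sample.pred_answer = output_text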
@@ -35,7 +35,8 @@ minigpt_4_coco_caption_model = dict(
     type='minigpt-4',
     low_resource=False,
     img_size=384,
-    llama_model='/path/to/vicuna-7b/',
+    llama_model='/path/to/vicuna_weights_7b/',
+    is_caption_task=True,
     prompt_constructor=dict(type=MiniGPT4COCOCaotionPromptConstructor,
                             image_prompt='###Human: <Img><ImageHere></Img>',
                             reply_prompt='###Assistant:'),
...
@@ -24,19 +24,20 @@ dataset = dict(type='mmpretrain.Flickr30kCaption',
                split='val',
                pipeline=val_pipeline)

-minigpt_4_flickr30k_dataloader = dict(
-    batch_size=1,
-    num_workers=4,
-    dataset=dataset,
-    collate_fn=dict(type='pseudo_collate'),
-    sampler=dict(type='DefaultSampler', shuffle=False))
+minigpt_4_flickr30k_dataloader = dict(batch_size=1,
+                                      num_workers=4,
+                                      dataset=dataset,
+                                      collate_fn=dict(type='pseudo_collate'),
+                                      sampler=dict(type='DefaultSampler',
+                                                   shuffle=False))

 # model settings
 minigpt_4_flickr30k_model = dict(
     type='minigpt-4',
     low_resource=False,
     img_size=384,
-    llama_model='/path/to/vicuna-7b/',
+    llama_model='/path/to/vicuna_weights_7b/',
+    is_caption_task=True,
     prompt_constructor=dict(type=MiniGPT4COCOCaotionPromptConstructor,
                             image_prompt='###Human: <Img><ImageHere></Img>',
                             reply_prompt='###Assistant:'),
@@ -46,7 +47,7 @@ minigpt_4_flickr30k_model = dict(

 minigpt_4_flickr30k_evaluator = [
     dict(
         type='mmpretrain.COCOCaption',
-        ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
+        ann_file='data/flickr30k/annotations/flickr30k_val_gt.json',
     )  # noqa
 ]
...
@@ -39,7 +39,7 @@ minigpt_4_gqa_model = dict(type='minigpt-4',
                            low_resource=False,
                            img_size=224,
                            max_length=10,
-                           llama_model='/path/to/vicuna-7b/',
+                           llama_model='/path/to/vicuna_weights_7b/',
                            prompt_constructor=dict(
                                type=MiniGPT4VQAPromptConstructor,
                                image_prompt='###Human: <Img><ImageHere></Img>',
...
@@ -41,7 +41,7 @@ minigpt_4_ocr_vqa_model = dict(
     low_resource=False,
     img_size=224,
     max_length=10,
-    llama_model='/path/to/vicuna-7b/',
+    llama_model='/path/to/vicuna_weights_7b/',
     prompt_constructor=dict(type=MiniGPT4VQAPromptConstructor,
                             image_prompt='###Human: <Img><ImageHere></Img>',
                             reply_prompt='###Assistant:'),
...
@@ -43,7 +43,7 @@ minigpt_4_ok_vqa_model = dict(
     low_resource=False,
     img_size=224,
     max_length=10,
-    llama_model='/path/to/vicuna-7b/',
+    llama_model='/path/to/vicuna_weights_7b/',
     prompt_constructor=dict(type=MiniGPT4VQAPromptConstructor,
                             image_prompt='###Human: <Img><ImageHere></Img>',
                             reply_prompt='###Assistant:'),
...
@@ -40,7 +40,7 @@ minigpt_4_scienceqa_model = dict(
     low_resource=False,
     img_size=224,
     max_length=10,
-    llama_model='/path/to/vicuna-7b/',
+    llama_model='/path/to/vicuna_weights_7b/',
     prompt_constructor=dict(type=MiniGPT4ScienceQAPromptConstructor,
                             image_prompt='###Human: <Img><ImageHere></Img>',
                             reply_prompt='###Assistant:'),
...
@@ -43,7 +43,7 @@ minigpt_4_textvqa_model = dict(
     low_resource=False,
     img_size=224,
     max_length=10,
-    llama_model='/path/to/vicuna-7b/',
+    llama_model='/path/to/vicuna_weights_7b/',
     prompt_constructor=dict(type=MiniGPT4VQAPromptConstructor,
                             image_prompt='###Human: <Img><ImageHere></Img>',
                             reply_prompt='###Assistant:'),
...
@@ -40,7 +40,7 @@ minigpt_4_vizwiz_model = dict(
     low_resource=False,
     img_size=224,
     max_length=10,
-    llama_model='/path/to/vicuna-7b/',
+    llama_model='/path/to/vicuna_weights_7b/',
     prompt_constructor=dict(type=MiniGPT4VQAPromptConstructor,
                             image_prompt='###Human: <Img><ImageHere></Img>',
                             reply_prompt='###Assistant:'),
...
@@ -43,7 +43,7 @@ minigpt_4_vqav2_model = dict(
     low_resource=False,
     img_size=224,
     max_length=10,
-    llama_model='/path/to/vicuna-7b/',
+    llama_model='/path/to/vicuna_weights_7b/',
     prompt_constructor=dict(type=MiniGPT4VQAPromptConstructor,
                             image_prompt='###Human: <Img><ImageHere></Img>',
                             reply_prompt='###Assistant:'),
...
@@ -37,10 +37,10 @@ minigpt_4_vsr_dataloader = dict(batch_size=1,
 # model settings
 minigpt_4_vsr_model = dict(
     type='minigpt-4',
-    low_resource=True,
+    low_resource=False,
     img_size=224,
     max_length=10,
-    llama_model='/path/to/vicuna-7b/',
+    llama_model='/path/to/vicuna_weights_7b/',
     prompt_constructor=dict(type=MiniGPT4VSRPromptConstructor,
                             image_prompt='###Human: <Img><ImageHere></Img>',
                             reply_prompt='###Assistant:'),
...
@@ -50,6 +50,8 @@ class MiniGPT4Inferencer(MiniGPT4):
         img_size (int): The size of image. Defaults to 224.
         low_resource (bool): Whether loaded in low precision.
             Defaults to False.
+        is_caption_task (bool): Whether the task is caption task.
+            Defaults to False.
     """

     def __init__(self,
@@ -60,6 +62,7 @@ class MiniGPT4Inferencer(MiniGPT4):
                  max_length: int = 30,
                  img_size: int = 224,
                  low_resource: bool = False,
+                 is_caption_task: bool = False,
                  mode: str = 'generation',
                  n_segments: int = 1) -> None:
         super().__init__(llama_model=llama_model,
@@ -83,6 +86,7 @@ class MiniGPT4Inferencer(MiniGPT4):
                                            post_processor, MM_MODELS)
         self.do_sample = do_sample
         self.max_length = max_length
+        self.is_caption_task = is_caption_task

     def forward(self, batch):
         if self.mode == 'generation':
@@ -193,6 +197,9 @@ class MiniGPT4Inferencer(MiniGPT4):
             output_token = outputs[i]
             output_text = self.post_processor(output_token,
                                               self.llama_tokenizer)
-            data_sample.pred_answer = output_text
+            if self.is_caption_task:
+                data_sample.pred_caption = output_text
+            else:
+                data_sample.pred_answer = output_text
             data_samples[i] = data_sample
         return data_samples
...
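
For reference, a caption config built on this change only needs to flip the new flag and point its evaluator at a matching ground-truth file. A trimmed, hypothetical example assembled from the configs above (the weight path is a placeholder, as it is there):

    # Hypothetical caption config using the new flag; field names follow
    # the COCO/Flickr30k configs in this diff.
    minigpt_4_caption_model = dict(
        type='minigpt-4',
        low_resource=False,
        img_size=384,
        llama_model='/path/to/vicuna_weights_7b/',  # placeholder path
        is_caption_task=True,  # predictions land in data_sample.pred_caption
    )

    minigpt_4_caption_evaluator = [
        dict(type='mmpretrain.COCOCaption',
             ann_file='data/flickr30k/annotations/flickr30k_val_gt.json')
    ]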