Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
opencompass
Commits
a205629f
Unverified
Commit
a205629f
authored
Aug 10, 2023
by
Yuan Liu
Committed by
GitHub
Aug 10, 2023
Browse files
[Feature]: Refactor input and output (#176)
* [Feature]: Refactor input and output * [Feature]: Update tasks
parent
876ade71
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
148 additions
and
70 deletions
+148
-70
configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
+12
-8
configs/multimodal/tasks.py
configs/multimodal/tasks.py
+3
-3
opencompass/multimodal/models/minigpt_4/__init__.py
opencompass/multimodal/models/minigpt_4/__init__.py
+6
-1
opencompass/multimodal/models/minigpt_4/minigpt_4.py
opencompass/multimodal/models/minigpt_4/minigpt_4.py
+21
-57
opencompass/multimodal/models/minigpt_4/post_processor.py
opencompass/multimodal/models/minigpt_4/post_processor.py
+34
-0
opencompass/multimodal/models/minigpt_4/prompt_constructor.py
...compass/multimodal/models/minigpt_4/prompt_constructor.py
+55
-0
opencompass/tasks/mm_infer.py
opencompass/tasks/mm_infer.py
+17
-1
No files found.
configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
View file @
a205629f
from
opencompass.multimodal.models.minigpt_4
import
(
MiniGPT4MMBenchPromptConstructor
,
MiniGPT4PostProcessor
)
# dataloader settings
# dataloader settings
val_pipeline
=
[
val_pipeline
=
[
dict
(
type
=
'mmpretrain.torchvision/Resize'
,
dict
(
type
=
'mmpretrain.torchvision/Resize'
,
...
@@ -9,8 +12,8 @@ val_pipeline = [
...
@@ -9,8 +12,8 @@ val_pipeline = [
std
=
(
0.26862954
,
0.26130258
,
0.27577711
)),
std
=
(
0.26862954
,
0.26130258
,
0.27577711
)),
dict
(
type
=
'mmpretrain.PackInputs'
,
dict
(
type
=
'mmpretrain.PackInputs'
,
algorithm_keys
=
[
algorithm_keys
=
[
'question'
,
'category'
,
'l2-category'
,
'context'
,
'question'
,
'category'
,
'l2-category'
,
'context'
,
'index'
,
'index'
,
'options_dict'
,
'options'
,
'split'
'options_dict'
,
'options'
,
'split'
])
])
]
]
...
@@ -27,11 +30,12 @@ minigpt_4_dataloader = dict(batch_size=1,
...
@@ -27,11 +30,12 @@ minigpt_4_dataloader = dict(batch_size=1,
# model settings
# model settings
minigpt_4_model
=
dict
(
minigpt_4_model
=
dict
(
type
=
'minigpt-4-mmbench'
,
type
=
'minigpt-4-mmbench'
,
low_resource
=
True
,
low_resource
=
False
,
llama_model
=
'/path/to/vicuna'
,
llama_model
=
'/path/to/vicuna-7b/'
,
sys_prompt
=
# noqa: E251
prompt_constructor
=
dict
(
type
=
MiniGPT4MMBenchPromptConstructor
,
'###Human: What is the capital of China? There are several options:
\n
A. Beijing
\n
B. Shanghai
\n
C. Guangzhou
\n
D. Shenzhen
\n
###Assistant: A
\n
'
image_prompt
=
'###Human: <Img><ImageHere></Img>'
,
)
reply_prompt
=
'###Assistant:'
),
post_processor
=
dict
(
type
=
MiniGPT4PostProcessor
))
# evaluation settings
# evaluation settings
minigpt_4_evaluator
=
[
minigpt_4_evaluator
=
[
...
@@ -39,4 +43,4 @@ minigpt_4_evaluator = [
...
@@ -39,4 +43,4 @@ minigpt_4_evaluator = [
save_path
=
'work_dirs/minigpt-4-7b-mmbench.xlsx'
)
save_path
=
'work_dirs/minigpt-4-7b-mmbench.xlsx'
)
]
]
minigpt_4_load_from
=
'/path/to/
minigpt-4
'
# noqa
minigpt_4_load_from
=
'/path/to/
prerained_minigpt4_7b.pth
'
# noqa
configs/multimodal/tasks.py
View file @
a205629f
...
@@ -10,6 +10,6 @@ models = [minigpt_4_model]
...
@@ -10,6 +10,6 @@ models = [minigpt_4_model]
datasets
=
[
minigpt_4_dataloader
]
datasets
=
[
minigpt_4_dataloader
]
evaluators
=
[
minigpt_4_evaluator
]
evaluators
=
[
minigpt_4_evaluator
]
load_froms
=
[
minigpt_4_load_from
]
load_froms
=
[
minigpt_4_load_from
]
num_gpus
=
1
num_gpus
=
8
num_procs
=
1
num_procs
=
8
launcher
=
'
slurm
'
launcher
=
'
pytorch
'
opencompass/multimodal/models/minigpt_4/__init__.py
View file @
a205629f
from
.minigpt_4
import
MiniGPT4MMBench
from
.minigpt_4
import
MiniGPT4MMBench
from
.post_processor
import
MiniGPT4PostProcessor
from
.prompt_constructor
import
MiniGPT4MMBenchPromptConstructor
__all__
=
[
'MiniGPT4MMBench'
]
__all__
=
[
'MiniGPT4MMBench'
,
'MiniGPT4PostProcessor'
,
'MiniGPT4MMBenchPromptConstructor'
]
opencompass/multimodal/models/minigpt_4/minigpt_4.py
View file @
a205629f
import
os
import
os
import
re
import
sys
import
sys
import
mmengine
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmengine.device
import
get_device
from
mmengine.device
import
get_device
...
@@ -43,15 +43,16 @@ class MiniGPT4MMBench(MiniGPT4):
...
@@ -43,15 +43,16 @@ class MiniGPT4MMBench(MiniGPT4):
Args:
Args:
llama_model (str): The path to the vicuna model.
llama_model (str): The path to the vicuna model.
sys_
prompt
(str): The prompt added to the beginning
prompt
_constructor (dict): The config of prompt constructor.
of each query. Defaults to ''
.
post_processor (dict): The config of post processor
.
low_resource (bool): Whether loaded in low precision.
low_resource (bool): Whether loaded in low precision.
Defaults to False.
Defaults to False.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
llama_model
:
str
,
llama_model
:
str
,
sys_prompt
:
str
=
''
,
prompt_constructor
:
dict
,
post_processor
:
dict
,
low_resource
:
bool
=
False
)
->
None
:
low_resource
:
bool
=
False
)
->
None
:
super
().
__init__
(
llama_model
=
llama_model
,
low_resource
=
low_resource
)
super
().
__init__
(
llama_model
=
llama_model
,
low_resource
=
low_resource
)
...
@@ -62,7 +63,10 @@ class MiniGPT4MMBench(MiniGPT4):
...
@@ -62,7 +63,10 @@ class MiniGPT4MMBench(MiniGPT4):
]
]
self
.
stopping_criteria
=
StoppingCriteriaList
(
self
.
stopping_criteria
=
StoppingCriteriaList
(
[
StoppingCriteriaSub
(
stops
=
stop_words_ids
)])
[
StoppingCriteriaSub
(
stops
=
stop_words_ids
)])
self
.
sys_prompt
=
sys_prompt
self
.
prompt_constructor
=
mmengine
.
registry
.
build_from_cfg
(
prompt_constructor
,
MM_MODELS
)
self
.
post_processor
=
mmengine
.
registry
.
build_from_cfg
(
post_processor
,
MM_MODELS
)
def
encode_img
(
self
,
image
):
def
encode_img
(
self
,
image
):
device
=
image
.
device
device
=
image
.
device
...
@@ -96,38 +100,13 @@ class MiniGPT4MMBench(MiniGPT4):
...
@@ -96,38 +100,13 @@ class MiniGPT4MMBench(MiniGPT4):
def
generate
(
self
,
batch
):
def
generate
(
self
,
batch
):
inputs
=
self
.
pack_inputs
(
batch
)
inputs
=
self
.
pack_inputs
(
batch
)
image
=
inputs
.
pop
(
'image'
)
inputs
=
self
.
prompt_constructor
(
inputs
)
image
=
inputs
[
'image'
]
prompt
=
inputs
[
'prompt'
]
data_samples
=
inputs
[
'data_samples'
]
data_samples
=
inputs
[
'data_samples'
]
samples
=
{
'image'
:
image
}
question
=
[
data_sample
.
get
(
'question'
)
for
data_sample
in
data_samples
]
options
=
[
data_sample
.
get
(
'options'
)
for
data_sample
in
data_samples
]
samples
.
update
({
'question'
:
question
[
0
]})
samples
.
update
({
'options'
:
options
[
0
]})
if
data_samples
[
0
].
get
(
'context'
)
is
not
None
:
context
=
[
data_sample
.
get
(
'context'
)
for
data_sample
in
data_samples
]
samples
.
update
({
'context'
:
context
})
data_sample
=
data_samples
[
0
]
img_prompt
=
'###Human: <Img><ImageHere></Img> '
if
'context'
in
samples
:
context_prompt
=
samples
[
'context'
][
0
]
question
=
samples
[
'question'
]
options
=
samples
[
'options'
]
if
'context'
in
samples
:
prompt
=
img_prompt
+
' '
+
context_prompt
+
' '
+
question
+
' '
+
options
# noqa
else
:
prompt
=
img_prompt
+
' '
+
question
+
' '
+
options
# prompt = self.sys_prompt + prompt
prompt
=
prompt
+
'###Assistant:'
image
=
samples
[
'image'
]
img_embeds
,
_
=
self
.
encode_img
(
image
)
# The main process of generation
img_embeds
,
_
=
self
.
encode_img
(
image
)
prompt_segs
=
prompt
.
split
(
'<ImageHere>'
)
prompt_segs
=
prompt
.
split
(
'<ImageHere>'
)
prompt_seg_tokens
=
[
prompt_seg_tokens
=
[
self
.
llama_tokenizer
(
seg
,
self
.
llama_tokenizer
(
seg
,
...
@@ -157,25 +136,10 @@ class MiniGPT4MMBench(MiniGPT4):
...
@@ -157,25 +136,10 @@ class MiniGPT4MMBench(MiniGPT4):
stopping_criteria
=
self
.
stopping_criteria
,
stopping_criteria
=
self
.
stopping_criteria
,
num_return_sequences
=
1
)
num_return_sequences
=
1
)
output_token
=
outputs
[
0
]
for
i
,
data_sample
in
enumerate
(
data_samples
):
if
output_token
[
0
]
==
0
:
output_token
=
outputs
[
i
]
output_token
=
output_token
[
1
:]
output_text
=
self
.
post_processor
(
output_token
,
if
output_token
[
0
]
==
1
:
self
.
llama_tokenizer
)
output_token
=
output_token
[
1
:]
output_text
=
self
.
llama_tokenizer
.
decode
(
output_token
,
add_special_tokens
=
False
)
output_text
=
self
.
post_process
(
output_text
)
data_sample
.
pred_answer
=
output_text
data_sample
.
pred_answer
=
output_text
return
data_sample
data_samples
[
i
]
=
data_sample
return
data_samples
def
post_process
(
self
,
output_text
):
output_text
=
output_text
.
split
(
'###'
)[
0
]
output_text
=
output_text
.
split
(
'Assistant:'
)[
-
1
].
strip
()
output_text
=
output_text
.
strip
(
'</s><s>'
)
output_text
=
output_text
.
strip
(
'</Img>'
)
output_text
=
output_text
.
strip
()
pattern
=
re
.
compile
(
r
'([A-Z]\.)'
)
res
=
pattern
.
findall
(
output_text
)
if
len
(
res
)
>
0
:
output_text
=
res
[
0
][:
-
1
]
return
output_text
opencompass/multimodal/models/minigpt_4/post_processor.py
0 → 100644
View file @
a205629f
import
re
import
torch
class MiniGPT4PostProcessor:
    """Post processor for MiniGPT-4 on MMBench.

    Decodes generated token ids with the supplied tokenizer and reduces the
    decoded text to the predicted option letter when one can be found.
    """

    # Special markers that may surround the decoded answer. They are removed
    # as whole substrings: ``str.strip('</s><s>')`` would instead strip any
    # of the characters ``<``, ``/``, ``s``, ``>`` and eat real answer text.
    _MARKERS = ('</s>', '<s>', '</Img>')
    # An upper-case option letter followed by a period, e.g. ``A.``.
    _OPTION_PATTERN = re.compile(r'([A-Z]\.)')

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
        """Decode the generated tokens and extract the answer.

        Args:
            output_token (torch.Tensor): Generated token ids of one sample.
            tokenizer: Object providing
                ``decode(tokens, add_special_tokens=False)``.

        Returns:
            str: The option letter if the decoded text contains one,
                otherwise the cleaned decoded text.
        """
        # Drop a leading token id 0 and then a leading token id 1
        # (presumably <unk>/<s> of the LLaMA vocabulary -- TODO confirm).
        # The length guards keep an exhausted sequence from raising
        # IndexError.
        if len(output_token) > 0 and output_token[0] == 0:
            output_token = output_token[1:]
        if len(output_token) > 0 and output_token[0] == 1:
            output_token = output_token[1:]
        output_text = tokenizer.decode(output_token,
                                       add_special_tokens=False)  # noqa
        return self._extract_key_words(output_text)

    def _strip_markers(self, text: str) -> str:
        """Remove whole special markers and whitespace from both ends."""
        text = text.strip()
        changed = True
        while changed:
            changed = False
            for marker in self._MARKERS:
                if text.startswith(marker):
                    text = text[len(marker):].lstrip()
                    changed = True
                if text.endswith(marker):
                    text = text[:-len(marker)].rstrip()
                    changed = True
        return text

    def _extract_key_words(self, output_text: str) -> str:
        """Reduce decoded text to the predicted option letter.

        Args:
            output_text (str): Text decoded from the generated tokens.

        Returns:
            str: The first upper-case letter that is followed by a period,
                or the cleaned text when no such letter exists.
        """
        # Keep only the text before the next '###' turn separator and after
        # the last 'Assistant:' tag.
        output_text = output_text.split('###')[0]
        output_text = output_text.split('Assistant:')[-1]
        output_text = self._strip_markers(output_text)

        match = self._OPTION_PATTERN.search(output_text)
        if match is not None:
            # Drop the trailing period of e.g. 'A.'.
            output_text = match.group(1)[:-1]
        return output_text
opencompass/multimodal/models/minigpt_4/prompt_constructor.py
0 → 100644
View file @
a205629f
from
typing
import
List
from
mmpretrain.structures
import
DataSample
class MiniGPT4MMBenchPromptConstructor:
    """Prompt constructor for MiniGPT-4 on MMBench.

    Args:
        image_prompt (str): Image prompt. Defaults to ''.
        reply_prompt (str): Reply prompt. Defaults to ''.
    """

    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
        self.image_prompt = image_prompt
        self.reply_prompt = reply_prompt

    def __call__(self, inputs: dict) -> dict:
        """Construct prompt.

        Args:
            inputs (dict): Input data containing image and data_samples.

        Returns:
            dict: The same ``inputs`` dict, updated in place with an
                additional ``'prompt'`` entry.
        """
        inputs['prompt'] = self._process(inputs['data_samples'])
        return inputs

    def _process(self, data_samples: List['DataSample']) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.

        Raises:
            AssertionError: If the batch does not hold exactly one sample.
        """
        # Raised explicitly (instead of an ``assert`` statement) so the
        # check still runs under ``python -O``; the exception type is kept.
        if len(data_samples) != 1:
            raise AssertionError('Only support batch size 1.')

        data_sample = data_samples[0]
        parts = [self.image_prompt]
        # The context is optional and is left out when the sample has none.
        context = data_sample.get('context')
        if context is not None:
            parts.append(context)
        parts.extend([
            data_sample.get('question'),
            data_sample.get('options'),
            self.reply_prompt,
        ])
        return ' '.join(parts)
opencompass/tasks/mm_infer.py
View file @
a205629f
...
@@ -4,7 +4,7 @@ import os
...
@@ -4,7 +4,7 @@ import os
import
os.path
as
osp
import
os.path
as
osp
import
random
import
random
import
time
import
time
from
typing
import
Sequence
from
typing
import
List
,
Sequence
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -78,6 +78,22 @@ class MultimodalInferTask:
...
@@ -78,6 +78,22 @@ class MultimodalInferTask:
return
osp
.
join
(
model_name
,
return
osp
.
join
(
model_name
,
f
'
{
dataset_name
}
-
{
evaluator_name
}
.
{
file_extension
}
'
)
f
'
{
dataset_name
}
-
{
evaluator_name
}
.
{
file_extension
}
'
)
def get_output_paths(self, file_extension: str = 'json') -> List[str]:
    """Get the paths to the output files.

    Args:
        file_extension (str): The file extension of the output file.
            Default: 'json'.

    Returns:
        List[str]: A single-element list holding the path
            ``<model type>/<dataset type>/<evaluator type>.<extension>``.
    """
    model_name = self.model['type']
    dataset_name = self.dataloader['dataset']['type']
    # Only the first evaluator's type contributes to the file name.
    evaluator_name = self.evaluator[0]['type']

    return [
        osp.join(model_name, dataset_name,
                 f'{evaluator_name}.{file_extension}')
    ]
def
get_command
(
self
,
cfg_path
,
template
):
def
get_command
(
self
,
cfg_path
,
template
):
"""Get the command template for the task.
"""Get the command template for the task.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment