OpenDAS / opencompass, commit 97fdc511 (unverified)
Authored Sep 21, 2023 by Yike Yuan; committed via GitHub on Sep 21, 2023

[Fix] Fix performance issue of visualglm. (#424)

* [Fix] Visualglm performance fixed.
* [Fix] Hide ckpt path.
Parent: 8803f7f7

Showing 5 changed files with 113 additions and 135 deletions.
Changed files:

configs/multimodal/visualglm/visualglm_6b_coco_caption.py (+1, -1)
configs/multimodal/visualglm/visualglm_6b_flickr30k.py (+1, -1)
opencompass/multimodal/models/visualglm/post_processor.py (+4, -6)
opencompass/multimodal/models/visualglm/prompt_constructor.py (+79, -89)
opencompass/multimodal/models/visualglm/visualglm.py (+28, -38)
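Taken together, the commit reworks the VisualGLM pipeline to process exactly one sample per batch instead of padding a list of prompts, switches the image-placeholder tokens from pad to unk, strips the prompt tokens from the generated sequence before post-processing, and replaces the short greedy decoding defaults with longer sampling-based defaults. The two caption configs also change the system prompt from 'A photo of' to 'Describe the image.'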
configs/multimodal/visualglm/visualglm_6b_coco_caption.py

```diff
@@ -32,7 +32,7 @@ visualglm_coco_caption_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     is_caption_task=True,
     prompt_constructor=dict(type=VisualGLMBasePromptConstructor,
-                            system_prompt='A photo of'),
+                            system_prompt='Describe the image.'),
     post_processor=dict(type=VisualGLMBasePostProcessor))
```
configs/multimodal/visualglm/visualglm_6b_flickr30k.py

```diff
@@ -33,7 +33,7 @@ visualglm_flickr30k_model = dict(
     type='visualglm',
     pretrained_path='/path/to/visualglm',  # or Huggingface repo id
     is_caption_task=True,
     prompt_constructor=dict(type=VisualGLMBasePromptConstructor,
-                            system_prompt='A photo of'),
+                            system_prompt='Describe the image.'),
     post_processor=dict(type=VisualGLMBasePostProcessor))
```
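For reference, the COCO-Caption model config after this commit reads as follows (assembled from the context lines in the diff above; the flickr30k config is identical apart from the dict name):

```python
visualglm_coco_caption_model = dict(
    type='visualglm',
    pretrained_path='/path/to/visualglm',  # or Huggingface repo id
    is_caption_task=True,
    prompt_constructor=dict(type=VisualGLMBasePromptConstructor,
                            system_prompt='Describe the image.'),
    post_processor=dict(type=VisualGLMBasePostProcessor))
```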
opencompass/multimodal/models/visualglm/post_processor.py

```diff
@@ -9,9 +9,8 @@ class VisualGLMBasePostProcessor:
     def __init__(self) -> None:
         pass
 
-    def __call__(self, output_token: torch.tensor, tokenizer: Any,
-                 input_len: int) -> str:
-        return tokenizer.decode(output_token[input_len:])
+    def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
+        return tokenizer.decode(output_token)
 
 
 class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
@@ -20,9 +19,8 @@ class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
     def __init__(self) -> None:
         super().__init__()
 
-    def __call__(self, output_token: torch.tensor, tokenizer: Any,
-                 input_len: int) -> str:
-        output_text = tokenizer.decode(output_token[input_len:])
+    def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
+        output_text = tokenizer.decode(output_token)
         if 'yes' in output_text.lower():
             return 'yes'
         elif 'no' in output_text.lower():
```
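The post-processors no longer slice off the prompt themselves; generate() in visualglm.py (further down) now strips the prompt tokens before handing the sequence over. A minimal sketch of the new contract, using a stand-in tokenizer (toy values, illustration only):

```python
from typing import Any


class FakeTokenizer:
    """Stand-in for the real tokenizer; illustration only."""

    def decode(self, tokens: Any) -> str:
        return ' '.join(str(t) for t in tokens)


tokenizer = FakeTokenizer()
full_output = [1, 2, 3, 9, 8, 7]  # prompt tokens + generated tokens (toy ids)
input_len = 3                     # length of the encoded prompt

# New contract: the caller slices off the prompt, the post-processor decodes.
generated = full_output[input_len:]
print(tokenizer.decode(generated))  # -> '9 8 7'
```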
opencompass/multimodal/models/visualglm/prompt_constructor.py

```python
import torch


class VisualGLMMMBenchPromptConstructor:
    """MMBench prompt constructor for VisualGLM.

    The overall prompt will be formulated as
    "system_prompt" + "human_prompt" + "image_prompt" + question +
    "assistant_prompt".

    Args:
        system_prompt (str): System prompt. (Default: '')
        human_prompt (str): Human prompt. (Default: 'Q:')
        image_prompt (str): Image prompt. (Default: '<img></img>')
        assistant_prompt (str): Assistant prompt. (Default: 'A:')
    """

    def __init__(self,
                 system_prompt: str = '',
                 human_prompt: str = 'Q:',
                 image_prompt: str = '<img></img>',
                 assistant_prompt: str = 'A:') -> None:
        self.image_prompt = image_prompt
        self.system_prompt = system_prompt
        self.human_prompt = human_prompt
        self.assistant_prompt = assistant_prompt
```
@@ -33,26 +25,18 @@ class VisualGLMMMBenchPromptConstructor:
A tuple containing images, prompt, data_samples and image_position.
"""
images
=
batch
.
pop
(
'inputs'
)
images
=
torch
.
stack
(
images
,
dim
=
0
)
data_samples
=
batch
.
pop
(
'data_samples'
)
questions
=
[
sample
.
get
(
'question'
)
for
sample
in
data_samples
]
options
=
[
sample
.
get
(
'options'
)
for
sample
in
data_samples
]
contexts
=
[
sample
.
get
(
'context'
)
for
sample
in
data_samples
]
contexts
=
[
c
if
c
else
''
for
c
in
contexts
]
# generate text prompt
prompt
=
[
'{}{}{}{}{}{}{}'
.
format
(
self
.
system_prompt
,
self
.
image_prompt
,
self
.
human_prompt
,
context
,
question
,
option
,
self
.
assistant_prompt
)
for
context
,
question
,
option
in
zip
(
contexts
,
questions
,
options
)
]
assert
len
(
batch
[
'inputs'
])
==
1
image
=
batch
.
pop
(
'inputs'
)[
0
].
unsqueeze
(
0
)
data_sample
=
batch
.
pop
(
'data_samples'
)[
0
]
img_prompt
=
'<img></img>'
if
data_sample
.
get
(
'context'
)
is
not
None
:
prompt
=
img_prompt
+
self
.
system_prompt
+
self
.
human_prompt
+
data_sample
.
context
+
' '
+
data_sample
.
question
+
' '
+
data_sample
.
options
# noqa
else
:
prompt
=
img_prompt
+
self
.
system_prompt
+
self
.
human_prompt
+
data_sample
.
question
+
' '
+
data_sample
.
options
# noqa
prompt
+=
self
.
assistant_prompt
image_position
=
prompt
.
rfind
(
'<img>'
)
+
5
image_position
=
5
return
images
,
prompt
,
data_samples
,
image_position
return
image
,
prompt
,
data_sample
,
image_position
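To make the single-sample flow concrete, here is a small sketch of the rewritten MMBench prompt assembly with hypothetical field values (the real ones come from the data_sample object):

```python
# Hypothetical sample values; the logic mirrors the added lines above.
system_prompt, human_prompt, assistant_prompt = '', 'Q:', 'A:'
context = 'The image shows an animal.'
question = 'What animal is it?'
options = 'A. cat B. dog'

prompt = '<img></img>' + system_prompt + human_prompt + context + ' ' \
    + question + ' ' + options
prompt += assistant_prompt
# Index just past the '<img>' tag, where the image embedding is spliced in.
image_position = prompt.rfind('<img>') + 5

print(prompt)
# <img></img>Q:The image shows an animal. What animal is it? A. cat B. dogA:
print(image_position)  # 5
```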
```diff
@@ -61,10 +45,17 @@ class VisualGLMBasePromptConstructor:
     The prompt will concat <img> and the given system prompt.
 
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
-    def __init__(self, system_prompt='') -> None:
+    def __init__(self,
+                 system_prompt: str = '',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
         self.prompt = system_prompt
+        self.human_prompt = human_prompt
+        self.assistant_prompt = assistant_prompt
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
```
```diff
@@ -76,16 +67,16 @@ class VisualGLMBasePromptConstructor:
             A tuple containing images, prompt, data_samples and image_position.
         """
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
 
         # generate text prompt
-        prompt = ['<img></img>' + self.prompt for i in range(images.shape[0])]
-
-        image_position = 5
+        prompt = '<img></img>' + self.human_prompt + self.prompt + self.assistant_prompt  # noqa
+
+        image_position = prompt.rfind('<img>') + 5
 
-        return images, prompt, data_samples, image_position
+        return image, prompt, data_sample, image_position
```
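Combined with the config change above, a caption query through the base constructor now renders as follows (a sketch; the system prompt string comes from the updated configs, the Q:/A: wrappers are the new defaults):

```python
prompt = '<img></img>' + 'Q:' + 'Describe the image.' + 'A:'
# -> '<img></img>Q:Describe the image.A:'
image_position = prompt.rfind('<img>') + 5  # = 5, right after the '<img>' tag
```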
```diff
@@ -94,10 +85,15 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
     The prompt will concat <img>, the question and the system prompt.
 
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
```
```diff
@@ -109,19 +105,18 @@ class VisualGLMVQAPromptConstructor(VisualGLMBasePromptConstructor):
             A tuple containing images, prompt, data_samples and image_position.
         """
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [sample.get('question') for sample in data_samples]
-
-        # generate text prompt
-        prompt = [
-            '<img></img>Q:{} {}\nA:'.format(question, self.prompt)
-            for question in questions
-        ]
-
-        image_position = 5
-
-        return images, prompt, data_samples, image_position
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+
+        # generate text prompt
+        question = data_sample.get('question')
+        prompt = '<img></img>' + self.human_prompt + question + self.prompt
+        prompt += '\n' + self.assistant_prompt
+
+        image_position = prompt.rfind('<img>') + 5
+
+        return image, prompt, data_sample, image_position
```
```diff
@@ -130,12 +125,17 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
     The prompt will concat image and all terms in a question.
 
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
     choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
 
-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
```
```diff
@@ -147,33 +147,24 @@ class VisualGLMScienceQAPromptConstructor(VisualGLMBasePromptConstructor):
             A tuple containing images, prompt, data_samples and image_position.
         """
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [
-            'Q: ' + sample.get('question') + '\n' for sample in data_samples
-        ]
-        choices = [sample.get('choices') for sample in data_samples]
-        choices = [[
-            f'({self.choice_mapping[i]}) ' + item
-            for i, item in enumerate(choice)
-        ] for choice in choices]
-        choices = [
-            'Choices: ' + ' '.join(choice) + '\n' for choice in choices
-        ]  # noqa
-        contexts = [
-            'Context: ' + data_sample.get('hint') + '\n'
-            for data_sample in data_samples
-        ]  # noqa
-
-        # generate text prompt
-        prompt = [
-            '<img></img>' + context + question + choice + self.prompt
-            for context, question, choice in zip(contexts, questions, choices)
-        ]
-
-        image_position = 5
-
-        return images, prompt, data_samples, image_position
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+        questions = 'Question: ' + data_sample.get('question')
+        choices = data_sample.get('choices')
+        choices = [
+            f'({self.choice_mapping[i]}) ' + item
+            for i, item in enumerate(choices)
+        ]
+        choices = 'Choices: ' + ' '.join(choices) + '\n'
+        contexts = 'Context: ' + data_sample.get('hint') + '\n'
+
+        # generate text prompt
+        prompt = '<img></img>' + self.human_prompt + contexts + questions + choices + self.prompt + self.assistant_prompt  # noqa
+        image_position = prompt.rfind('<img>') + 5
+
+        return image, prompt, data_sample, image_position
```
```diff
@@ -182,10 +173,15 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
     The prompt will concat <img>, the question and the system prompt.
 
     Args:
         system_prompt (str): System prompt. (Default: '')
+        human_prompt (str): Human prompt. (Default: 'Q:')
+        assistant_prompt (str): Assistant prompt. (Default: 'A:')
     """
 
-    def __init__(self, system_prompt='') -> None:
-        super().__init__(system_prompt)
+    def __init__(self,
+                 system_prompt='',
+                 human_prompt: str = 'Q:',
+                 assistant_prompt: str = 'A:') -> None:
+        super().__init__(system_prompt, human_prompt, assistant_prompt)
 
     def __call__(self, batch: dict) -> tuple:
         """Construct prompt.
```
```diff
@@ -197,22 +193,16 @@ class VisualGLMIconQAPromptConstructor(VisualGLMBasePromptConstructor):
             A tuple containing images, prompt, data_samples and image_position.
         """
-        images = batch.pop('inputs')
-        images = torch.stack(images, dim=0)
-        data_samples = batch.pop('data_samples')
-        questions = [
-            'Q: ' + sample.get('question') + '\n' for sample in data_samples
-        ]
-        choices = [sample.get('choices') for sample in data_samples]
-        choices = [
-            'Options: ' + ', '.join(choice) + '.\n' for choice in choices
-        ]  # noqa
-
-        # generate text prompt
-        prompt = [
-            '<img></img>' + question + choice + self.prompt
-            for question, choice in zip(questions, choices)
-        ]
-
-        image_position = 5
-
-        return images, prompt, data_samples, image_position
+        assert len(batch['inputs']) == 1
+        image = batch.pop('inputs')[0].unsqueeze(0)
+        data_sample = batch.pop('data_samples')[0]
+        questions = data_sample.get('question') + '\n'
+        choices = data_sample.get('choices')
+        choices = 'Options: ' + ', '.join(choices) + '.\n'
+
+        # generate text prompt
+        prompt = '<img></img>' + self.human_prompt + questions + choices + self.prompt + self.assistant_prompt  # noqa
+        image_position = prompt.rfind('<img>') + 5
+
+        return image, prompt, data_sample, image_position
```
opencompass/multimodal/models/visualglm/visualglm.py

```diff
@@ -43,39 +43,31 @@ class VisualGLM(nn.Module):
         if gen_kwargs:
             self.gen_kwargs = gen_kwargs
         else:
-            self.gen_kwargs = dict(max_new_tokens=30,
-                                   num_beams=1,
-                                   do_sample=False,
-                                   repetition_penalty=1.0,
-                                   length_penalty=-1.0)
+            self.gen_kwargs = dict(max_length=1024,
+                                   min_length=100,
+                                   do_sample=True,
+                                   temperature=0.8,
+                                   top_p=0.4,
+                                   top_k=100,
+                                   repetition_penalty=1.2)
 
         self.is_caption_task = is_caption_task
```
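The default decoding setup moves from short greedy generation (30 new tokens, beam size 1) to sampling with a much larger length budget, which is the substance of the performance fix on generation-style benchmarks. Explicit gen_kwargs still take precedence; a minimal sketch of an override, assuming the constructor keyword names shown in the diff context (the import path and override values are hypothetical):

```python
# Sketch only: assumes VisualGLM is exported from this module and accepts
# pretrained_path / gen_kwargs keywords, as the diff context suggests.
from opencompass.multimodal.models.visualglm import VisualGLM

model = VisualGLM(
    pretrained_path='/path/to/visualglm',  # or Huggingface repo id
    gen_kwargs=dict(max_length=512,        # hypothetical override values
                    do_sample=False,
                    num_beams=3),
)
```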
```diff
-    def encode_by_tokenizer(self, multi_prompts, image_position):
-        input_ids = []
-        max_seq_length = 0
-        for prompt in multi_prompts:
-            input0 = self.tokenizer.encode(prompt[:image_position],
-                                           add_special_tokens=False)
-            input1 = [self.tokenizer.pad_token_id] * self.model.image_length
-            input2 = self.tokenizer.encode(prompt[image_position:],
-                                           add_special_tokens=False)
-            input_all = sum([input0, input1, input2], [])
-            input_all = self.tokenizer.build_inputs_with_special_tokens(
-                input_all)
-            max_seq_length = max(max_seq_length, len(input_all))
-            input_ids.append(input_all)
-            pre_image_len = len(input0)
-
-        # padding
-        for i, _ in enumerate(input_ids):
-            pad_len = max_seq_length - len(input_ids[i])
-            input_ids[i] = [self.tokenizer.pad_token_id
-                            ] * pad_len + input_ids[i]
-
-        return input_ids, pre_image_len
+    def encode_by_tokenizer(self, prompt, image_position):
+        input0 = self.tokenizer.encode(prompt[:image_position],
+                                       add_special_tokens=False)
+        input1 = [self.tokenizer.unk_token_id] * self.model.image_length
+        input2 = self.tokenizer.encode(prompt[image_position:],
+                                       add_special_tokens=False)
+        input_all = sum([input0, input1, input2], [])
+        input_all = self.tokenizer.build_inputs_with_special_tokens(input_all)
+        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
+        input_all = input_all.unsqueeze(0)
+        pre_image_len = len(input0)
+
+        return input_all, pre_image_len
```
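The rewritten encoder returns a single [1, seq_len] tensor laid out as: text tokens before the image tag, one placeholder per visual token (now unk_token_id instead of pad_token_id), then the trailing text tokens. A toy sketch of that layout with hypothetical token ids:

```python
# Toy illustration of the token layout built by encode_by_tokenizer.
image_length = 32                        # self.model.image_length
unk_token_id = 0                         # hypothetical placeholder id

input0 = [101, 102]                      # encode(prompt[:image_position])
input1 = [unk_token_id] * image_length   # slots later filled by image features
input2 = [103, 104, 105]                 # encode(prompt[image_position:])

input_all = input0 + input1 + input2
pre_image_len = len(input0)              # offset where image features are written
```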
```diff
     def generate(self, batch):
         # process input
@@ -87,26 +79,24 @@ class VisualGLM(nn.Module):
         input_all, pre_image_len = self.encode_by_tokenizer(
             prompt, image_position)
-        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
 
         # build input param
         inputs = {
             'input_ids': input_all,
             'pre_image_length': pre_image_len,
             'images': image
         }
 
         # generate answer
         outputs = self.model.generate(**inputs, **self.gen_kwargs)
 
         # format output
-        outputs = outputs.tolist()
-        for i, sample in enumerate(data_sample):
-            answer = self.post_processor(outputs[i], self.tokenizer,
-                                         input_all.shape[1])
-            if self.is_caption_task:
-                data_sample[i].pred_caption = answer
-            else:
-                data_sample[i].pred_answer = answer
+        outputs = outputs.tolist()[0][input_all.shape[1]:]
+        answer = self.post_processor(outputs, self.tokenizer)
+        if self.is_caption_task:
+            data_sample.pred_caption = answer
+        else:
+            data_sample.pred_answer = answer
 
         return data_sample
```