Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
3c772593
Commit
3c772593
authored
Oct 20, 2024
by
Baber
Browse files
add attn_mask (llava models need it)
parent
4142b7b2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
10 deletions
+26
-10
lm_eval/models/hf_vlms.py
lm_eval/models/hf_vlms.py
+26
-10
No files found.
lm_eval/models/hf_vlms.py
View file @
3c772593
import
copy
import
copy
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -21,6 +21,9 @@ from lm_eval.models.utils import (
...
@@ -21,6 +21,9 @@ from lm_eval.models.utils import (
from
lm_eval.utils
import
add_padding_if_needed
from
lm_eval.utils
import
add_padding_if_needed
if
TYPE_CHECKING
:
import
PIL
DEFAULT_IMAGE_PLACEHOLDER
=
"<image>"
DEFAULT_IMAGE_PLACEHOLDER
=
"<image>"
...
@@ -175,7 +178,9 @@ class HFMultimodalLM(HFLM):
...
@@ -175,7 +178,9 @@ class HFMultimodalLM(HFLM):
return
text_encoding
,
encoding
# image_encoding is a dict
return
text_encoding
,
encoding
# image_encoding is a dict
def
_encode_multimodal_pair
(
self
,
context
,
continuation
,
images
):
def
_encode_multimodal_pair
(
self
,
context
,
continuation
,
images
:
List
[
"PIL.Image.Image"
]
):
"""Helper function to perform the role of TemplateLM._encode_pair
"""Helper function to perform the role of TemplateLM._encode_pair
Except allowing for image input to also be processed alongside `context`.
Except allowing for image input to also be processed alongside `context`.
...
@@ -192,6 +197,9 @@ class HFMultimodalLM(HFLM):
...
@@ -192,6 +197,9 @@ class HFMultimodalLM(HFLM):
context
,
DEFAULT_IMAGE_PLACEHOLDER
,
self
.
image_token
,
self
.
max_images
context
,
DEFAULT_IMAGE_PLACEHOLDER
,
self
.
image_token
,
self
.
max_images
)
)
if
self
.
rgb
:
images
=
[
img
.
convert
(
"RGB"
)
for
img
in
images
]
whole_enc
,
image_enc
=
self
.
tok_multimodal_encode
(
whole_enc
,
image_enc
=
self
.
tok_multimodal_encode
(
context
+
continuation
,
images
context
+
continuation
,
images
)
)
...
@@ -346,7 +354,7 @@ class HFMultimodalLM(HFLM):
...
@@ -346,7 +354,7 @@ class HFMultimodalLM(HFLM):
"""
"""
# note: imgs is a dict.
# note: imgs is a dict.
with
torch
.
no_grad
():
with
torch
.
no_grad
():
return
self
.
model
(
inps
,
attention_mask
=
torch
.
ones_like
(
inps
),
**
imgs
).
logits
return
self
.
model
(
inps
,
**
imgs
,
attention_mask
=
attn_mask
).
logits
def
_model_multimodal_generate
(
self
,
inputs
,
max_length
,
stop
,
**
generation_kwargs
):
def
_model_multimodal_generate
(
self
,
inputs
,
max_length
,
stop
,
**
generation_kwargs
):
generation_kwargs
[
"temperature"
]
=
generation_kwargs
.
get
(
"temperature"
,
0.0
)
generation_kwargs
[
"temperature"
]
=
generation_kwargs
.
get
(
"temperature"
,
0.0
)
...
@@ -384,7 +392,9 @@ class HFMultimodalLM(HFLM):
...
@@ -384,7 +392,9 @@ class HFMultimodalLM(HFLM):
batched_imgs
[
key
]
=
torch
.
cat
(
batched_imgs
[
key
]
=
torch
.
cat
(
[
[
torch
.
tensor
(
torch
.
tensor
(
image_enc
[
key
],
device
=
self
.
device
,
dtype
=
self
.
model
.
dtype
image_enc
[
key
],
device
=
self
.
device
,
dtype
=
self
.
model
.
dtype
if
key
==
"pixel_values"
else
torch
.
int
,
)
)
for
image_enc
in
image_encs
for
image_enc
in
image_encs
],
],
...
@@ -453,15 +463,16 @@ class HFMultimodalLM(HFLM):
...
@@ -453,15 +463,16 @@ class HFMultimodalLM(HFLM):
# allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
# allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
# speeds up some multiple-choice tasks proportionally to the number of choices.
# speeds up some multiple-choice tasks proportionally to the number of choices.
# groups requests by context+continuation[:-1] and infer on one request/group.
# groups requests by context+continuation[:-1] and infer on one request/group.
return
req
[
-
1
]
+
req
[
-
3
]
+
req
[
-
2
]
[:
-
1
]
return
req
[
-
3
]
+
req
[
-
2
]
re_ord
=
Collator
(
re_ord
=
Collator
(
requests
,
requests
,
sort_fn
=
_collate
,
sort_fn
=
_collate
,
group_by
=
"contexts"
# TODO: can't group-by just "contexts" any more, need to incorporate imgs
group_by
=
None
,
if
self
.
backend
==
"causal"
and
self
.
logits_cache
# group_by="contexts" # TODO: can't group-by just "contexts" any more, need to incorporate imgs
else
None
,
# if self.backend == "causal" and self.logits_cache
group_fn
=
_lookup_one_token_cont
,
# else None,
# group_fn=_lookup_one_token_cont,
)
)
# automatic (variable) batch size detection for vectorization
# automatic (variable) batch size detection for vectorization
...
@@ -545,7 +556,12 @@ class HFMultimodalLM(HFLM):
...
@@ -545,7 +556,12 @@ class HFMultimodalLM(HFLM):
)
# TODO: fix/test for bs>1 case with differently-sized imgs!
)
# TODO: fix/test for bs>1 case with differently-sized imgs!
multi_logits
=
F
.
log_softmax
(
multi_logits
=
F
.
log_softmax
(
self
.
_model_multimodal_call
(
batched_inps
,
batched_imgs
,
**
call_kwargs
),
self
.
_model_multimodal_call
(
batched_inps
,
batched_imgs
,
attn_mask
=
torch
.
ones_like
(
batched_inps
),
**
call_kwargs
,
),
dim
=-
1
,
dim
=-
1
,
)
# [batch, padding_length (inp or cont), vocab]
)
# [batch, padding_length (inp or cont), vocab]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment