Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
8a8c2982
Commit
8a8c2982
authored
Jul 01, 2024
by
lintangsutawika
Browse files
rework doc_to_visual
parent
7d7a3a1c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
39 deletions
+42
-39
lm_eval/api/task.py
lm_eval/api/task.py
+42
-39
No files found.
lm_eval/api/task.py
View file @
8a8c2982
...
@@ -1278,14 +1278,21 @@ class ConfigurableTask(Task):
...
@@ -1278,14 +1278,21 @@ class ConfigurableTask(Task):
else
:
else
:
raise
TypeError
raise
TypeError
def
doc_to_visual
(
self
,
doc
:
dict
)
->
Union
[
int
,
str
,
list
]:
def
doc_to_visual
(
self
,
doc
:
Any
)
->
Union
[
int
,
str
,
list
]:
if
self
.
config
.
doc_to_visual
is
None
:
eval_logger
.
error
(
"doc_to_visual was called but not set in config"
)
else
:
doc_to_visual
=
self
.
config
.
doc_to_visual
if
isinstance
(
self
.
config
.
doc_to_visual
,
str
):
if
isinstance
(
self
.
config
.
doc_to_visual
,
str
):
assert
self
.
config
.
doc_to_visual
in
self
.
features
if
doc_to_visual
in
self
.
features
:
# Single Image. Still return a list for consistency
return
doc
[
doc_to_visual
]
return
doc
[
self
.
config
.
doc_to_visual
]
else
:
else
:
assert
callable
(
self
.
config
.
doc_to_visual
)
return
ast
.
literal_eval
(
utils
.
apply_template
(
doc_to_visual
,
doc
))
return
self
.
config
.
doc_to_visual
(
doc
)
elif
callable
(
doc_to_visual
):
return
doc_to_visual
(
doc
)
else
:
return
None
def
construct_requests
(
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
...
@@ -1307,16 +1314,6 @@ class ConfigurableTask(Task):
...
@@ -1307,16 +1314,6 @@ class ConfigurableTask(Task):
# Otherwise they are placed in the continuation
# Otherwise they are placed in the continuation
arguments
=
[(
ctx
,
f
"
{
target_delimiter
}{
cont
}
"
)
for
cont
in
choices
]
arguments
=
[(
ctx
,
f
"
{
target_delimiter
}{
cont
}
"
)
for
cont
in
choices
]
request_list
=
[
Instance
(
request_type
=
"loglikelihood"
,
doc
=
doc
,
arguments
=
arg
,
idx
=
i
,
**
kwargs
,
)
for
i
,
arg
in
enumerate
(
arguments
)
]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if
"acc_mutual_info"
in
self
.
_metric_fn_list
.
keys
():
if
"acc_mutual_info"
in
self
.
_metric_fn_list
.
keys
():
# if we are calculating multiple choice accuracy
# if we are calculating multiple choice accuracy
...
@@ -1325,31 +1322,37 @@ class ConfigurableTask(Task):
...
@@ -1325,31 +1322,37 @@ class ConfigurableTask(Task):
# here mutual info refers to calculating
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
# in other words normalizing by subtracting the unconditional logprob of each choice.
request_list
.
extend
(
aux_arguments
=
[(
""
,
f
"
{
choice
}
"
)
for
choice
in
choices
]
[
else
:
aux_arguments
=
None
elif
self
.
OUTPUT_TYPE
==
"generate_until"
:
arguments
=
(
ctx
,
deepcopy
(
self
.
config
.
generation_kwargs
))
if
self
.
doc_to_visual
:
if
isinstance
(
arguments
,
list
):
arguments
=
[
arg
+
(
self
.
doc_to_visual
(
doc
),)
for
arg
in
arguments
]
else
:
arguments
=
arguments
+
(
self
.
doc_to_visual
(
doc
),)
if
isinstance
(
arguments
,
type
):
if
aux_arguments
is
not
None
:
all_arg_list
=
[
arguments
,
arg_list
]
else
:
all_arg_list
=
[
arguments
]
for
arg_list
in
all_arg_list
:
request_list
=
[
Instance
(
Instance
(
request_type
=
"loglikelihood"
,
request_type
=
"loglikelihood"
,
doc
=
doc
,
doc
=
doc
,
arguments
=
(
""
,
"{}"
.
format
(
choice
))
,
arguments
=
arg
,
idx
=
i
,
idx
=
i
,
**
kwargs
,
**
kwargs
,
)
)
for
i
,
choice
in
enumerate
(
choices
)
for
i
,
arg
in
enumerate
(
arg_list
)
]
]
)
return
request_list
elif
self
.
OUTPUT_TYPE
==
"generate_until"
:
return
request_list
if
self
.
INPUT_TYPE
==
"text_image"
:
arguments
=
(
ctx
,
deepcopy
(
self
.
config
.
generation_kwargs
),
self
.
doc_to_visual
,
doc
,
self
.
config
.
task
,
)
elif
self
.
INPUT_TYPE
==
"text"
:
arguments
=
(
ctx
,
deepcopy
(
self
.
config
.
generation_kwargs
))
return
Instance
(
return
Instance
(
request_type
=
self
.
OUTPUT_TYPE
,
request_type
=
self
.
OUTPUT_TYPE
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment