gaoqiong / lm-evaluation-harness · Commits

Commit 8181f43c, authored Jan 16, 2025 by Baber
Parent: 2106fbeb

    nits

Showing 3 changed files with 60 additions and 21 deletions:

    lm_eval/models/openai_completions.py     +2   -4
    lm_eval/tasks/mathvista/mathvista.yaml   +5   -1
    lm_eval/tasks/mathvista/utils.py         +53  -16
lm_eval/models/openai_completions.py  (view file @ 8181f43c)

@@ -5,8 +5,8 @@ import itertools
 import json
 import os
 from functools import cached_property
-from operator import itemgetter
 from io import BytesIO
+from operator import itemgetter
 from typing import Any, Dict, List, Optional, Tuple, Union

 from PIL import Image
@@ -15,10 +15,8 @@ from tqdm import tqdm
 from lm_eval.api.instance import Instance
 from lm_eval.api.registry import register_model
-from lm_eval.models.api_models import TemplateAPI
-from lm_eval.models.utils import handle_stop_sequences
 from lm_eval.models.api_models import JsonChatStr, TemplateAPI
-from lm_eval.models.utils import Collator
+from lm_eval.models.utils import Collator, handle_stop_sequences
 from lm_eval.utils import eval_logger
lm_eval/tasks/mathvista/mathvista.yaml  (view file @ 8181f43c)

@@ -6,7 +6,6 @@ output_type: "generate_until"
 doc_to_image:
   - decoded_image
 doc_to_text: "<image>{{query}}"
-#doc_to_choice: '{{ ["A", "B", "C", "D", "E", "F"][:choices.length] }}'
 doc_to_target: answer
 process_results: !function utils.process_results
 generation_kwargs:
@@ -15,6 +14,11 @@ generation_kwargs:
   temperature: 0.0
   do_sample: false
   max_gen_toks: 1024
+filter_list:
+  - name: "extract_answer"
+    filter:
+      - function: "custom"
+        filter_fn: !function utils.build_predictions
 metric_list:
   - metric: acc
     aggregation: mean
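The new filter_list hooks answer extraction into lm-eval's filter stage instead of doing it inside process_results. build_predictions itself is not part of this diff; the sketch below is a hypothetical shape for such a custom filter_fn, assuming it is called with the same (resps, docs) arguments that extract_all_answers in utils.py (further down) takes, and reusing the quick-extraction regex from that file.

import re

_ANSWER_RE = re.compile(r'The answer is "(.*)"\.')

def build_predictions(resps, docs):
    # Hypothetical sketch only -- the real utils.build_predictions is not shown
    # in this commit. For each doc, take the first generation and pull out the
    # quoted answer if the model followed the 'The answer is "..."' convention.
    out = []
    for resp, doc in zip(resps, docs):
        match = _ANSWER_RE.search(resp[0])
        out.append([match.group(1) if match else resp[0]])
    return out

Returning a singleton list per doc keeps the usual filtered-response shape, which would line up with the response[0] indexing that the reworked process_results does below.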
lm_eval/tasks/mathvista/utils.py  (view file @ 8181f43c)

 import re
+from typing import Optional

+# from api_model import make_concurrent_requests
 from Levenshtein import distance

@@ -53,6 +54,13 @@ def create_test_prompt(demo_prompt, query, response):
     return full_prompt


+def verify_extraction(extraction):
+    extraction = extraction.strip()
+    if not extraction:
+        return False
+    return True
+
+
 # taken from https://github.com/lupantech/MathVista/blob/main/evaluation/calculate_score.py
 def get_most_similar(prediction: str, choices: list) -> float:
     """
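verify_extraction simply treats any non-whitespace string as a usable extraction, for example:

verify_extraction("B")      # True
verify_extraction("  \n")   # False -- strips to the empty string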
@@ -140,24 +148,16 @@ def safe_equal(prediction, answer):
return
False
def
extract_answer
(
response
:
str
,
problem
:
dict
)
->
str
:
def
extract_answer
(
response
:
str
,
problem
:
dict
,
quick_extract
=
True
)
->
str
:
question_type
=
problem
[
"question_type"
]
answer_type
=
problem
[
"answer_type"
]
choices
=
problem
[
"choices"
]
# query = problem["query"]
# pid = problem[
'
pid
'
]
# pid = problem[
"
pid
"
]
if
response
==
""
:
return
""
### This is not in the original code:
extract
=
re
.
findall
(
r
"[tT]he answer is ([A-Za-z0-9]+(?:\.[A-Za-z0-9]+)?)"
,
response
)
if
extract
:
return
str
(
extract
[
0
])
###
if
question_type
==
"multi_choice"
and
response
in
choices
:
return
response
...
...
@@ -175,7 +175,39 @@ def extract_answer(response: str, problem: dict) -> str:
except
Exception
:
pass
return
response
# quick extraction
if
quick_extract
:
# The answer is "text". -> "text"
try
:
result
=
re
.
search
(
r
'The answer is "(.*)"\.'
,
response
)
if
result
:
extraction
=
result
.
group
(
1
)
return
extraction
except
Exception
:
pass
# general extraction
# try:
# full_prompt = create_test_prompt(DEMO_PROMPT, query, response)
# extraction = make_concurrent_requests(full_prompt)
# return extraction
# except Exception:
# print(
# f"Error in extracting answer for problem: {pid} with response: {response}"
# )
# # logging.info(f"Error in extracting answer for problem: {pid} with response: {response}")
# # logging.info(e)
return
""
def
extract_all_answers
(
resps
:
list
[
list
[
str
]],
docs
:
dict
,
quick_extract
=
True
)
->
list
[
str
]:
return
[
extract_answer
(
resp
[
0
],
doc
,
quick_extract
=
quick_extract
)
for
resp
,
doc
in
zip
(
resps
,
docs
)
]
# adapted from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py
...
...
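Together with the yaml change, the net effect is that the ad-hoc "[tT]he answer is ..." regex is dropped from extract_answer and replaced by the upstream MathVista quick-extraction pattern, which only fires on the literal 'The answer is "...".' phrasing. A small self-contained check of that pattern (the sample response string is made up for illustration):

import re

response = 'The final result: The answer is "42".'
match = re.search(r'The answer is "(.*)"\.', response)
print(match.group(1) if match else "")   # prints: 42

extract_all_answers just maps extract_answer over (resp[0], doc) pairs; despite the docs: dict annotation, it zips over docs, so it is really expecting one doc per response list.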
@@ -186,11 +218,16 @@ def process_results(doc: dict, results: list[str]):
     answer_type = doc["answer_type"]
     precision = doc["precision"]  # noqa: F841
     answer = doc["answer"]

-    extracted_answer = extract_answer(response, doc)
-    normalized_extraction = normalize_extracted_answer(
-        extracted_answer, choices, question_type, answer_type, precision
-    )
-    res = safe_equal(normalized_extraction, answer)
+    # step 1: extract the answer from the model response
+    # extracted_answer = extract_answer(response, doc)
+    extracted_answer = response[0]
+    if verify_extraction(extracted_answer):
+        normalized_extraction = normalize_extracted_answer(
+            extracted_answer, choices, question_type, answer_type, precision
+        )
+        res = safe_equal(normalized_extraction, answer)
+    else:
+        res = False

     return {"acc": 1.0} if res else {"acc": 0.0}