gaoqiong / lm-evaluation-harness · Commits

Commit 0dcdfb80
Authored Jul 01, 2024 by JessicaOjo
Parent: ebe41744

    remove openai fixes and unused regex

Showing 6 changed files with 16 additions and 159 deletions (+16 / -159).
Changed files:

    lm_eval/api/instance.py               +1    -1
    lm_eval/api/metrics.py                +3    -16
    lm_eval/api/registry.py               +0    -1
    lm_eval/api/task.py                   +12   -53
    lm_eval/filters/extraction.py         +0    -42
    lm_eval/models/openai_completions.py  +0    -46
lm_eval/api/instance.py

@@ -3,7 +3,7 @@ from typing import Literal, Optional, Tuple

 OutputType = Literal[
-    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice", "multiple_choice_gpt"
+    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
 ]
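
For context, OutputType is just a typing Literal listing the request types the harness still accepts; a minimal sketch of how such a literal can be checked at runtime (the validate_output_type helper is hypothetical, not part of the module):

# Minimal sketch: runtime check against the Literal shown above.
# validate_output_type is a hypothetical helper, not part of lm_eval.
from typing import Literal, get_args

OutputType = Literal[
    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]


def validate_output_type(output_type: str) -> None:
    # get_args() recovers the allowed strings from the Literal for a runtime check.
    if output_type not in get_args(OutputType):
        raise ValueError(f"unsupported output_type: {output_type!r}")


validate_output_type("multiple_choice")        # ok
# validate_output_type("multiple_choice_gpt")  # would now raise ValueError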
lm_eval/api/metrics.py

@@ -23,20 +23,7 @@ def bypass_agg(arr):

 @register_aggregation("mean")
 def mean(arr):
-    if isinstance(arr[0], (list, np.ndarray)):
-        return sum(arr[0]) / len(arr[0])
-    else:
-        return sum(arr) / len(arr)
-
-
-@register_aggregation("acc_gpt")
-def acc_gpt(arr):
-    unzipped_list = list(zip(*arr))
-    golds = unzipped_list[0]
-    preds = unzipped_list[1]
-    accuracy = sklearn.metrics.accuracy_score(golds, preds)
-    return accuracy
+    return sum(arr) / len(arr)


 @register_aggregation("median")

@@ -165,7 +152,7 @@ def brier_score_fn(items):  # This is a passthrough function
 @register_metric(
     metric="acc",
     higher_is_better=True,
-    output_type=["loglikelihood", "multiple_choice", "multiple_choice_gpt"],
+    output_type=["loglikelihood", "multiple_choice"],
     aggregation="mean",
 )
 def acc_fn(items):  # This is a passthrough function

@@ -291,7 +278,7 @@ def mcc_fn(items):  # This is a passthrough function
 @register_metric(
     metric="f1",
     higher_is_better=True,
-    output_type=["multiple_choice", "multiple_choice_gpt"],
+    output_type=["multiple_choice"],
     aggregation="f1",
 )
 def f1_fn(items):  # This is a passthrough function
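
For reference, register_aggregation just stores a reduction function under a name that task configs can refer to. A short sketch of registering a custom aggregation in the same style as mean() and median() above; the "geom_mean" name is hypothetical, and the import path assumes register_aggregation lives in the registry module shown later in this diff:

# Sketch of the decorator pattern used above; "geom_mean" is a hypothetical
# aggregation name, and the import assumes register_aggregation is exposed by
# lm_eval.api.registry, as the registry file in this commit suggests.
import math

from lm_eval.api.registry import register_aggregation


@register_aggregation("geom_mean")
def geom_mean(arr):
    # Geometric mean of per-sample scores; registered the same way as mean()/median().
    return math.exp(sum(math.log(x) for x in arr) / len(arr))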
lm_eval/api/registry.py

@@ -87,7 +87,6 @@ DEFAULT_METRIC_REGISTRY = {
     ],
     "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
     "multiple_choice": ["acc", "acc_norm"],
-    "multiple_choice_gpt": ["acc"],
     "generate_until": ["exact_match"],
 }
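
DEFAULT_METRIC_REGISTRY maps each output type to the metrics a task falls back to when its config does not list any. A minimal sketch of that lookup; the dict entries below are the ones visible in the hunk, and the helper function is hypothetical:

# Sketch of the default-metric lookup. The entries mirror the hunk above (the
# registry also has a "loglikelihood" entry outside this hunk), and
# default_metrics_for() is a hypothetical helper, not harness API.
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": ["acc", "acc_norm"],
    "generate_until": ["exact_match"],
}


def default_metrics_for(output_type: str) -> list:
    # Unknown output types simply get no default metrics.
    return DEFAULT_METRIC_REGISTRY.get(output_type, [])


assert default_metrics_for("multiple_choice") == ["acc", "acc_norm"]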
lm_eval/api/task.py

@@ -44,7 +44,6 @@ from lm_eval.prompts import get_prompt
 ALL_OUTPUT_TYPES = [
     "loglikelihood",
     "multiple_choice",
-    "multiple_choice_gpt",
     "loglikelihood_rolling",
     "generate_until",
 ]

@@ -1268,7 +1267,7 @@ class ConfigurableTask(Task):
             arguments = (ctx, self.doc_to_target(doc))
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             arguments = (self.doc_to_target(doc),)
-        elif "multiple_choice" in self.OUTPUT_TYPE:
+        elif self.OUTPUT_TYPE == "multiple_choice":
             choices = self.doc_to_choice(doc)
             target_delimiter = self.config.target_delimiter
             if self.multiple_input:

@@ -1280,28 +1279,16 @@ class ConfigurableTask(Task):
             else:
                 # Otherwise they are placed in the continuation
                 arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

-            if self.OUTPUT_TYPE == "multiple_choice_gpt":
-                request_list = [
-                    Instance(
-                        request_type="multiple_choice_gpt",
-                        doc=doc,
-                        arguments=arg,
-                        idx=i,
-                        **kwargs,
-                    )
-                    for i, arg in enumerate(arguments)
-                ]
-            else:
-                request_list = [
-                    Instance(
-                        request_type="loglikelihood",
-                        doc=doc,
-                        arguments=arg,
-                        idx=i,
-                        **kwargs,
-                    )
-                    for i, arg in enumerate(arguments)
-                ]
+            request_list = [
+                Instance(
+                    request_type="loglikelihood",
+                    doc=doc,
+                    arguments=arg,
+                    idx=i,
+                    **kwargs,
+                )
+                for i, arg in enumerate(arguments)
+            ]

             # TODO: we should raise a warning telling users this will at most ~2x runtime.
             if "acc_mutual_info" in self._metric_fn_list.keys():
                 # if we are calculating multiple choice accuracy

@@ -1446,35 +1433,6 @@ class ConfigurableTask(Task):
                 ]
                 acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
                 result_dict["acc_mutual_info"] = acc_mutual_info
-        elif self.OUTPUT_TYPE == "multiple_choice_gpt":
-            gold = self.doc_to_target(doc)
-            result = results[0]
-            choices = self.doc_to_choice(doc)
-            try:
-                gold = choices[gold]
-                gold = type(result)(gold)
-            except TypeError:
-                gold = gold
-
-            for metric in self._metric_fn_list.keys():
-                try:
-                    result_score = self._metric_fn_list[metric](
-                        references=[gold],
-                        predictions=[result],
-                        **self._metric_fn_kwargs[metric],
-                    )
-                except (TypeError):
-                    # TODO: this is hacky and I don't want to do it
-                    result_score = self._metric_fn_list[metric]([gold, result])
-
-                if isinstance(result_score, dict):
-                    # TODO: this handles the case where HF evaluate returns a dict.
-                    result_score = result_score[metric]
-                result_dict[metric] = result_score
         elif self.OUTPUT_TYPE == "generate_until":
             gold = self.doc_to_target(doc)
             result = results[0]
@@ -1534,6 +1492,7 @@ class ConfigurableTask(Task):
                         result_score = 0.0
                 else:
                     try:
+                        # adds exact match logic
                         if metric == "exact_match":
                             result_score = self._metric_fn_list[metric](
                                 references=[str(gold)],
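
The net effect of the task.py changes is that every multiple-choice document is scored through loglikelihood requests only: one Instance per answer choice, pairing the context with each delimited choice. A minimal sketch of that construction for a single document; the example doc, context, and delimiter are made up, while the Instance fields mirror lm_eval/api/instance.py:

# Sketch of the request construction kept by this commit. The Instance fields
# mirror lm_eval/api/instance.py; the doc, context, and delimiter are hypothetical.
from lm_eval.api.instance import Instance

doc = {"question": "2 + 2 = ?", "choices": ["3", "4"], "gold": 1}  # hypothetical doc
ctx = "Question: 2 + 2 = ?\nAnswer:"
target_delimiter = " "
choices = doc["choices"]

# One (context, " choice") pair per option, exactly as in the hunk above.
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

request_list = [
    Instance(request_type="loglikelihood", doc=doc, arguments=arg, idx=i)
    for i, arg in enumerate(arguments)
]
# The model returns a log-likelihood per choice and the task scores the argmax.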
lm_eval/filters/extraction.py

@@ -86,48 +86,6 @@ class VerbalizerFilter(Filter):
         return filtered_resps


-@register_filter("regex-numbers")
-class RegexNumberFilter(Filter):
-    """ """
-
-    def __init__(
-        self,
-        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
-        group_select=0,
-        fallback: str = 0,
-    ) -> None:
-        """
-        pass a string `regex` to run `re.compile(r"regex")` on.
-        `fallback` defines the output returned if no matches for the regex are located.
-        """
-        self.regex_pattern = regex_pattern
-        self.regex = re.compile(regex_pattern)
-        self.group_select = group_select
-        self.fallback = fallback
-
-    def apply(self, resps, docs):
-        # here, we assume we have a list, in which each element is
-        # a list of model responses for some particular input/target pair.
-        # so we process each of these (same input/target response sets)
-        # independently (and keep them a list.)
-        def filter_set(inst):
-            filtered = []
-            for resp in inst:
-                match = self.regex.findall(resp)
-                if match:
-                    match = match[self.group_select]
-                    if isinstance(match, tuple):
-                        match = [m for m in match if m][0]
-                    match = match.strip().replace(',', '').replace('.', '').lower()
-                else:
-                    match = self.fallback
-                filtered.append(match)
-            return filtered
-
-        filtered_resps = list(map(lambda x: filter_set(x), resps))
-
-        return filtered_resps
-
-
 @register_filter("remove_whitespace")
 class WhitespaceFilter(Filter):
     """ """
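
For reference, the number-extraction behaviour of the deleted RegexNumberFilter can be reproduced with plain re outside the harness; a standalone sketch mirroring its apply() loop on made-up responses:

# Standalone sketch mirroring the deleted RegexNumberFilter.apply() logic;
# the responses below are hypothetical and nothing here touches lm_eval.
import re

regex = re.compile(r"#### (\-?[0-9\.\,]+)")
fallback = 0  # value returned when the pattern does not match


def extract_number(resp: str):
    matches = regex.findall(resp)
    if not matches:
        return fallback
    match = matches[0]
    if isinstance(match, tuple):
        # With multiple capture groups, keep the first non-empty group.
        match = [m for m in match if m][0]
    # Strip separators the same way the deleted filter did.
    return match.strip().replace(",", "").replace(".", "").lower()


print(extract_number("... so the answer is #### 1,234."))  # -> "1234"
print(extract_number("no final answer given"))             # -> 0 (the fallback)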
lm_eval/models/openai_completions.py

@@ -471,52 +471,6 @@ class OpenaiChatCompletionsLM(LM):

         return grouper.get_original(res)

-    def multiple_choice_gpt(self, requests, disable_tqdm: bool = False) -> List[str]:
-        res = defaultdict(list)
-        re_ords = {}
-
-        # we group requests by their generation_kwargs,
-        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
-        # in the same batch.
-        grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
-        for key, reqs in grouper.get_grouped().items():
-            # within each set of reqs for given kwargs, we reorder by token length, descending.
-            re_ords[key] = utils.Reorderer(
-                [req.args for req in reqs], lambda x: (-len(x[0]), x[0])
-            )
-
-        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
-        for key, re_ord in re_ords.items():
-            # n needs to be 1 because messages in
-            # chat completion are not batch but
-            # is regarded as a single conversation.
-            chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
-            for chunk in chunks:
-                contexts, all_gen_kwargs = zip(*chunk)
-                inps = [
-                    {"role": "user", "content": context} for context in contexts
-                ]
-
-                response = oa_completion(
-                    client=self.client,
-                    chat=True,
-                    messages=inps,
-                    model=self.model,
-                )
-
-                for resp, (context, args_) in zip(response.choices, chunk):
-                    s = resp.message.content
-                    res[key].append(s)
-
-                    self.cache_hook.add_partial("multiple_choice_gpt", context, s)
-                    pbar.update(1)
-            # reorder this group of results back to original unsorted form
-            res[key] = re_ord.get_original(res[key])
-
-        pbar.close()
-
-        return grouper.get_original(res)
-
     def loglikelihood(self, requests, disable_tqdm: bool = False):
         raise NotImplementedError("No support for logits.")
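
The deleted method boiled down to one chat-completion call per context, with the raw assistant message kept as the prediction. A minimal sketch of that round trip against the plain OpenAI Python client, independent of the harness's oa_completion wrapper; the model name and prompt are placeholders:

# Sketch of the round trip the removed multiple_choice_gpt() performed, written
# against the plain OpenAI Python client rather than the harness's oa_completion
# wrapper; the model name and prompt below are placeholders.
from openai import OpenAI

client = OpenAI()  # expects OPENAI_API_KEY in the environment
context = "Question: ...\nChoices: A) ... B) ...\nAnswer with the letter only."

response = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": context}],
)
prediction = response.choices[0].message.content  # the string that was scored
print(prediction)

With this path removed, OpenaiChatCompletionsLM keeps only its generation interface; multiple-choice tasks that need per-choice log-likelihoods cannot run against it, since its loglikelihood method (shown above) raises NotImplementedError.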