gaoqiong / lm-evaluation-harness · Commits

Commit 039832e5
authored Dec 28, 2023 by lintangsutawika
removed passthrough fn
parent 3888193d

Showing 1 changed file with 88 additions and 185 deletions

lm_eval/api/metrics.py  (+88, -185)
import abc
import logging
import math
import random
from collections.abc import Iterable

import evaluate
import numpy as np
import sacrebleu
import sklearn.metrics

from lm_eval.api.registry import register_aggregation, register_metric


eval_logger = logging.getLogger("lm-eval")


class BaseMetric:
    def __init__(self) -> None:
        pass

    @abc.abstractmethod
    def update(self, *items):
        pass

    @abc.abstractmethod
    def compute(self, *items):
        pass


def mean(arr):
    return sum(arr) / len(arr)

...
@@ -37,32 +22,43 @@ def median(arr):
    return arr[len(arr) // 2]


def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)
@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
)
class PerplexityMetric(BaseMetric):
    def update(self, ll, is_greedy):
        return ll

    def compute(self, items):
        return math.exp(-mean(items))


def perplexity(items):
    return math.exp(-mean(items))
@register_metric(
    metric=["word_perplexity", "byte_perplexity"],
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))
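# Illustration only (not part of this commit): a minimal sketch of how the helpers
# above combine. The item shapes here are assumptions inferred from the signatures.
_token_lls = [-1.2, -0.7, -2.3]                   # hypothetical per-token loglikelihoods
_ppl = math.exp(-mean(_token_lls))                # what perplexity() computes
_doc_items = [(-35.2, 20), (-60.1, 41)]           # hypothetical (sum loglikelihood, n_words) pairs
_word_ppl = math.exp(-weighted_mean(_doc_items))  # what weighted_perplexity() computes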
@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type="loglikelihood",
)
class LoglikelihoodAccMetric(BaseMetric):
    def update(self, ll, is_greedy):
        return int(is_greedy)

    def compute(self, items):
        return mean(items)


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)
@register_aggregation("f1")
@register_metric(
    metric="f1",
    higher_is_better=True,
    output_type="multiple_choice",
)
def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
...
@@ -72,16 +68,23 @@ def f1_score(items):
    return np.max(fscore)
@register_aggregation("matthews_corrcoef")
@register_metric(
    metric="mcc",
    higher_is_better=True,
    output_type="multiple_choice",
)
def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    # print(preds)
    return sklearn.metrics.matthews_corrcoef(golds, preds)
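# Illustration only (not part of this commit): the item shape assumed by the
# function above, with each item a (gold, pred) label pair and labels made up.
_items = [(1, 1), (0, 0), (1, 0), (0, 0)]
_golds, _label_preds = zip(*_items)
_mcc = sklearn.metrics.matthews_corrcoef(_golds, _label_preds)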
@register_aggregation("bleu")
@register_metric(
    metric="bleu",
    higher_is_better=True,
    output_type="generate_until",
)
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence to a reference sentence. It counts matching
    ...
@@ -99,7 +102,11 @@ def bleu(items):
    return sacrebleu.corpus_bleu(preds, refs).score
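# Illustration only (not part of this commit): the shapes sacrebleu expects here,
# assuming preds is a list of system outputs and refs a list of reference streams
# aligned with it. The example strings are made up.
_preds = ["the cat sat on the mat", "hello there"]
_refs = [["the cat is on the mat", "hello there"]]
_bleu = sacrebleu.corpus_bleu(_preds, _refs).score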
@register_aggregation("chrf")
@register_metric(
    metric="chrf",
    higher_is_better=True,
    output_type="generate_until",
)
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    ...
@@ -114,7 +121,11 @@ def chrf(items):
    return sacrebleu.corpus_chrf(preds, refs).score
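# Illustration only (not part of this commit): corpus_chrf takes the same
# predictions / reference-streams shape as the BLEU sketch above.
_chrf = sacrebleu.corpus_chrf(_preds, _refs).score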
@register_aggregation("ter")
@register_metric(
    metric="ter",
    higher_is_better=True,
    output_type="generate_until",
)
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    ...
@@ -130,86 +141,34 @@ def ter(items):
    return sacrebleu.corpus_ter(preds, refs).score
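# Illustration only (not part of this commit): TER is an edit rate, computed by
# sacrebleu on the same predictions / reference-streams shape as above.
_ter = sacrebleu.corpus_ter(_preds, _refs).score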
# @register_metric(
#     metric="acc",
#     higher_is_better=True,
#     output_type=["loglikelihood", "multiple_choice"],
#     aggregation="mean",
# )
# def acc_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="acc_norm",
#     higher_is_better=True,
#     output_type=["loglikelihood", "multiple_choice"],
#     aggregation="mean",
# )
# def acc_norm_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="acc_mutual_info",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="mean",
# )
# def acc_mutual_info_fn(items):  # This is a passthrough function
#     return items


exact_match = evaluate.load("exact_match")


# @register_metric(
#     metric="exact_match",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="mean",
# )
# def exact_match_fn(**kwargs):
#     return exact_match.compute(**kwargs)


@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
class WordPerplexityMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _words

    def set_wise_compute(self, items):
        return math.exp(-weighted_mean(items))


@register_metric(
    metric=["acc", "acc_norm"],
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
)
def aggregate_acc_fn(items):
    return mean(items)


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
class BytePerplexityMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _bytes

    def set_wise_compute(self, items):
        return math.exp(-weighted_mean(items))


@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
)
def acc_mutual_info_fn(items):
    return mean(items)


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
)
class BitsPerByteMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _bytes

    def set_wise_compute(self, items):
        return -weighted_mean(items) / math.log(2)


@register_metric(
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
)
def exact_match_fn(**kwargs):
    return exact_match.compute(**kwargs)
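# Illustration only (not part of this commit): how the Hugging Face `evaluate`
# exact_match module loaded above is typically invoked. The argument names follow
# the evaluate library; the example strings are made up.
_em = exact_match.compute(
    predictions=["Paris", "4"],
    references=["Paris", "five"],
)["exact_match"]  # fraction of exact string matches, here 0.5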
def pop_stddev(arr):
...
@@ -226,79 +185,28 @@ def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))
# @register_metric(
#     metric="mcc",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="matthews_corrcoef",
# )
# def mcc_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="f1",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="f1",
# )
# def f1_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="bleu",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="bleu",
# )
# def bleu_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="chrf",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="chrf",
# )
# def chrf_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="ter",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="ter",
# )
# def ter_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="acc_all",
#     higher_is_better=True,
#     output_type="loglikelihood",
#     aggregation="mean",
# )
# def acc_all(items):
#     # Only count as correct if all answers are labeled correctly for each question
#     question_scoring_dict = {}
#     preds = list(zip(*items))[0]
#     docs = list(zip(*items))[1]
#     for doc, pred in zip(docs, preds):
#         paragraph_id = doc["idx"]["paragraph"]
#         question_id = doc["idx"]["question"]
#         if (paragraph_id, question_id) not in question_scoring_dict:
#             question_scoring_dict[(paragraph_id, question_id)] = []
#         gold_label = doc["label"] == 1
#         question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
#     acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
#     return acc
@register_metric(
    metric="acc_all",
    higher_is_better=True,
    output_type="loglikelihood",
)
def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc
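# Illustration only (not part of this commit): acc_all groups answers by
# (paragraph, question) and credits a question only when every answer is right.
# The doc layout below mirrors the fields read above; the values are made up.
_docs = [
    {"idx": {"paragraph": 0, "question": 0}, "label": 1},
    {"idx": {"paragraph": 0, "question": 0}, "label": 0},
    {"idx": {"paragraph": 0, "question": 1}, "label": 1},
]
_answer_preds = [True, True, True]           # second answer is wrong (gold is False), so question 0 scores 0
_acc = acc_all(list(zip(_answer_preds, _docs)))  # -> 0.5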
def acc_all_stderr(items):
...
@@ -328,11 +236,6 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    return max(scores_for_ground_truths)


def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)


def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)

...