gaoqiong / lm-evaluation-harness · Commits

Commit 039832e5, authored Dec 28, 2023 by lintangsutawika

    removed passthrough fn

Parent: 3888193d
Showing 1 changed file with 88 additions and 185 deletions.

lm_eval/api/metrics.py  (+88 / -185)
+import logging
 import math
+import random
 from collections.abc import Iterable
-import abc

+import evaluate
 import numpy as np
 import sacrebleu
 import sklearn.metrics
-import random
-import evaluate

-from lm_eval.api.registry import register_metric, register_aggregation
+from lm_eval.api.registry import register_metric

-import logging

 eval_logger = logging.getLogger("lm-eval")


-class BaseMetric:
-    def __init__(self,) -> None:
-
-    @abc.abstractmethod
-    def update(self, *items):
-        pass
-
-    @abc.abstractmethod
-    def compute(self, *items):
-        pass


 def mean(arr):
     return sum(arr) / len(arr)
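Note: the deleted BaseMetric base class split a metric into a per-sample update() and a corpus-level compute(). A minimal sketch of how such a class would have been subclassed and driven; the MeanAccMetric class and the toy values below are illustrative assumptions, not code from this repository:

    import abc


    class BaseMetric:
        # Same two-method contract as the class removed in this commit.
        @abc.abstractmethod
        def update(self, *items):
            pass

        @abc.abstractmethod
        def compute(self, *items):
            pass


    class MeanAccMetric(BaseMetric):
        # Hypothetical subclass: update() maps one (loglikelihood, is_greedy)
        # result to a 0/1 score, compute() averages the collected scores.
        def update(self, ll, is_greedy):
            return int(is_greedy)

        def compute(self, items):
            return sum(items) / len(items)


    metric = MeanAccMetric()
    per_sample = [metric.update(ll, greedy) for ll, greedy in [(-1.2, True), (-3.4, False)]]
    print(metric.compute(per_sample))  # 0.5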
@@ -37,32 +22,43 @@ def median(arr):
     return arr[len(arr) // 2]


+def weighted_mean(items):
+    a, b = zip(*items)
+    return sum(a) / sum(b)
+
+
 @register_metric(
     metric="perplexity",
     higher_is_better=False,
     output_type="loglikelihood",
 )
-class PerplexityMetric(BaseMetric):
-    def update(self, ll, is_greedy):
-        return ll
-
-    def compute(self, items):
-        return math.exp(-mean(items))
+def perplexity(items):
+    return math.exp(-mean(items))
+
+
+@register_metric(
+    metric=["word_perplexity", "byte_perplexity"],
+    higher_is_better=False,
+    output_type="loglikelihood_rolling",
+)
+def weighted_perplexity(items):  # This is a passthrough function
+    return math.exp(-weighted_mean(items))
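Note: the new perplexity and weighted_perplexity aggregations are the usual exp of a negated average log-likelihood; the weighted variant divides by the total word or byte count carried as the second element of each item. A small self-contained check with made-up numbers:

    import math


    def mean(arr):
        return sum(arr) / len(arr)


    def weighted_mean(items):
        a, b = zip(*items)
        return sum(a) / sum(b)


    # Per-document log-likelihoods (illustrative values only).
    lls = [-2.0, -4.0]
    print(math.exp(-mean(lls)))  # exp(3) ~= 20.09

    # (log-likelihood, word or byte count) pairs for a rolling perplexity.
    pairs = [(-2.0, 4), (-4.0, 2)]
    print(math.exp(-weighted_mean(pairs)))  # exp(1) ~= 2.72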
 @register_metric(
-    metric="acc",
-    higher_is_better=True,
-    output_type="loglikelihood",
+    metric="bits_per_byte",
+    higher_is_better=False,
+    output_type="loglikelihood_rolling",
 )
-class LoglikelihoodAccMetric(BaseMetric):
-    def update(self, ll, is_greedy):
-        return int(is_greedy)
-
-    def compute(self, items):
-        return math.exp(-mean(items))
+def bits_per_byte(items):
+    return -weighted_mean(items) / math.log(2)
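Note: bits_per_byte rescales the same (log-likelihood, byte count) pairs from nats to bits, i.e. -sum(ll) / (sum(bytes) * ln 2). A quick sanity check with fabricated numbers:

    import math


    def weighted_mean(items):
        a, b = zip(*items)
        return sum(a) / sum(b)


    # (log-likelihood in nats, number of bytes) per document; the values are
    # made up so that each byte costs exactly 2 bits.
    items = [(-math.log(2) * 8, 4), (-math.log(2) * 8, 4)]

    print(-weighted_mean(items) / math.log(2))  # 2.0 bits per byte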
-@register_aggregation("f1")
+@register_metric(
+    metric="f1",
+    higher_is_better=True,
+    output_type="multiple_choice",
+)
 def f1_score(items):
     unzipped_list = list(zip(*items))
     golds = unzipped_list[0]
     ...

@@ -72,16 +68,23 @@ def f1_score(items):
     return np.max(fscore)


-@register_aggregation("matthews_corrcoef")
+@register_metric(
+    metric="mcc",
+    higher_is_better=True,
+    output_type="multiple_choice",
+)
 def matthews_corrcoef(items):
     unzipped_list = list(zip(*items))
     golds = unzipped_list[0]
     preds = unzipped_list[1]
-    # print(preds)
     return sklearn.metrics.matthews_corrcoef(golds, preds)

-@register_aggregation("bleu")
+@register_metric(
+    metric="bleu",
+    higher_is_better=True,
+    output_type="generate_until",
+)
 def bleu(items):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching
     ...

@@ -99,7 +102,11 @@ def bleu(items):
     return sacrebleu.corpus_bleu(preds, refs).score
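Note: before and after this change, the BLEU, chrF and TER aggregations hand the collected predictions and references to sacrebleu's corpus-level scorers. A minimal standalone call with made-up sentences, assuming sacrebleu is installed (this skips whatever reference reshaping the collapsed part of this file performs):

    import sacrebleu

    preds = ["the cat sat on the mat", "a quick brown fox"]
    # sacrebleu expects a list of reference streams, each aligned with preds.
    refs = [["the cat sat on the mat", "the quick brown fox jumps"]]

    print(sacrebleu.corpus_bleu(preds, refs).score)
    print(sacrebleu.corpus_chrf(preds, refs).score)
    print(sacrebleu.corpus_ter(preds, refs).score)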

-@register_aggregation("chrf")
+@register_metric(
+    metric="chrf",
+    higher_is_better=True,
+    output_type="generate_until",
+)
 def chrf(items):
     """chrF++ is a tool for automatic evaluation of machine translation output
     based on character n-gram precision and recall enhanced with word n-grams.
     ...

@@ -114,7 +121,11 @@ def chrf(items):
     return sacrebleu.corpus_chrf(preds, refs).score

-@register_aggregation("ter")
+@register_metric(
+    metric="ter",
+    higher_is_better=True,
+    output_type="generate_until",
+)
 def ter(items):
     """Translation Error Rate is an error metric for machine translation that
     measures the number of edits required to change a system output into one
     ...

@@ -130,86 +141,34 @@ def ter(items):
     return sacrebleu.corpus_ter(preds, refs).score

-# @register_metric(
-#     metric="acc",
-#     higher_is_better=True,
-#     output_type=["loglikelihood", "multiple_choice"],
-#     aggregation="mean",
-# )
-# def acc_fn(items):  # This is a passthrough function
-#     return items
-
-# @register_metric(
-#     metric="acc_norm",
-#     higher_is_better=True,
-#     output_type=["loglikelihood", "multiple_choice"],
-#     aggregation="mean",
-# )
-# def acc_norm_fn(items):  # This is a passthrough function
-#     return items
-
-# @register_metric(
-#     metric="acc_mutual_info",
-#     higher_is_better=True,
-#     output_type="multiple_choice",
-#     aggregation="mean",
-# )
-# def acc_mutual_info_fn(items):  # This is a passthrough function
-#     return items
-
-exact_match = evaluate.load("exact_match")
-
-# @register_metric(
-#     metric="exact_match",
-#     higher_is_better=True,
-#     output_type="generate_until",
-#     aggregation="mean",
-# )
-# def exact_match_fn(**kwargs):
-#     return exact_match.compute(**kwargs)
-
 @register_metric(
-    metric="word_perplexity",
-    higher_is_better=False,
-    output_type="loglikelihood_rolling",
+    metric=["acc", "acc_norm"],
+    higher_is_better=True,
+    output_type=["loglikelihood", "multiple_choice"],
 )
-class BytePerplexityMetric(BaseMetric):
-    def sample_wise_compute(self, loglikelihood, _words, _bytes):
-        return loglikelihood, _words
-
-    def set_wise_compute(self, items):
-        return math.exp(-weighted_mean(items))
+def aggregate_acc_fn(items):
+    return mean(items)


 @register_metric(
-    metric="byte_perplexity",
-    higher_is_better=False,
-    output_type="loglikelihood_rolling",
+    metric="acc_mutual_info",
+    higher_is_better=True,
+    output_type="multiple_choice",
 )
-class BytePerplexityMetric(BaseMetric):
-    def sample_wise_compute(self, loglikelihood, _words, _bytes):
-        return loglikelihood, _bytes
-
-    def set_wise_compute(self, items):
-        return math.exp(-weighted_mean(items))
+def acc_mutual_info_fn(items):
+    return mean(items)
+
+
+exact_match = evaluate.load("exact_match")


 @register_metric(
-    metric="bits_per_byte",
-    higher_is_better=False,
-    output_type="loglikelihood_rolling",
+    metric="exact_match",
+    higher_is_better=True,
+    output_type="generate_until",
 )
-class BitsPerByteMetric(BaseMetric):
-    def sample_wise_compute(self, loglikelihood, _words, _bytes):
-        return loglikelihood, _bytes
-
-    def set_wise_compute(self, items):
-        return -weighted_mean(items) / math.log(2)
+def exact_match_fn(**kwargs):
+    return exact_match.compute(**kwargs)
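Note: register_metric comes from lm_eval.api.registry; after this commit it is applied directly to the aggregation functions (carrying the metric name, direction, and request output_type) instead of to passthrough wrappers. The decorator-with-arguments pattern, as a simplified stand-in rather than the actual lm-eval registry:

    # Simplified stand-in for lm_eval.api.registry.register_metric (illustrative only).
    METRIC_REGISTRY = {}


    def register_metric(**metadata):
        def decorate(fn):
            names = metadata["metric"]
            for name in names if isinstance(names, list) else [names]:
                METRIC_REGISTRY[name] = {"fn": fn, **metadata}
            return fn

        return decorate


    @register_metric(
        metric=["acc", "acc_norm"],
        higher_is_better=True,
        output_type=["loglikelihood", "multiple_choice"],
    )
    def aggregate_acc_fn(items):
        return sum(items) / len(items)


    print(METRIC_REGISTRY["acc"]["fn"]([1, 0, 1, 1]))  # 0.75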

 def pop_stddev(arr):
     ...

@@ -226,79 +185,28 @@ def mean_stderr(arr):
     return sample_stddev(arr) / math.sqrt(len(arr))

-# @register_metric(
-#     metric="mcc",
-#     higher_is_better=True,
-#     output_type="multiple_choice",
-#     aggregation="matthews_corrcoef",
-# )
-# def mcc_fn(items):  # This is a passthrough function
-#     return items
-
-# @register_metric(
-#     metric="f1",
-#     higher_is_better=True,
-#     output_type="multiple_choice",
-#     aggregation="f1",
-# )
-# def f1_fn(items):  # This is a passthrough function
-#     return items
-
-# @register_metric(
-#     metric="bleu",
-#     higher_is_better=True,
-#     output_type="generate_until",
-#     aggregation="bleu",
-# )
-# def bleu_fn(items):  # This is a passthrough function
-#     return items
-
-# @register_metric(
-#     metric="chrf",
-#     higher_is_better=True,
-#     output_type="generate_until",
-#     aggregation="chrf",
-# )
-# def chrf_fn(items):  # This is a passthrough function
-#     return items
-
-# @register_metric(
-#     metric="ter",
-#     higher_is_better=True,
-#     output_type="generate_until",
-#     aggregation="ter",
-# )
-# def ter_fn(items):  # This is a passthrough function
-#     return items
-
-# @register_metric(
-#     metric="acc_all",
-#     higher_is_better=True,
-#     output_type="loglikelihood",
-#     aggregation="mean",
-# )
-# def acc_all(items):
-#     # Only count as correct if all answers are labeled correctly for each question
-#     question_scoring_dict = {}
-#     preds = list(zip(*items))[0]
-#     docs = list(zip(*items))[1]
-#     for doc, pred in zip(docs, preds):
-#         paragraph_id = doc["idx"]["paragraph"]
-#         question_id = doc["idx"]["question"]
-#         if (paragraph_id, question_id) not in question_scoring_dict:
-#             question_scoring_dict[(paragraph_id, question_id)] = []
-#         gold_label = doc["label"] == 1
-#         question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
-#     acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
-#     return acc
+@register_metric(
+    metric="acc_all",
+    higher_is_better=True,
+    output_type="loglikelihood",
+)
+def acc_all(items):
+    # Only count as correct if all answers are labeled correctly for each question
+    question_scoring_dict = {}
+    preds = list(zip(*items))[0]
+    docs = list(zip(*items))[1]
+    for doc, pred in zip(docs, preds):
+        paragraph_id = doc["idx"]["paragraph"]
+        question_id = doc["idx"]["question"]
+        if (paragraph_id, question_id) not in question_scoring_dict:
+            question_scoring_dict[(paragraph_id, question_id)] = []
+        gold_label = doc["label"] == 1
+        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
+    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
+    return acc
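Note: acc_all only credits a question when every answer belonging to it is predicted correctly (the docs carry MultiRC-style paragraph/question indices). The grouping logic in miniature, with fabricated docs:

    import numpy as np

    # Each item is (prediction, doc); the doc records which question it belongs to.
    items = [
        (True, {"idx": {"paragraph": 0, "question": 0}, "label": 1}),
        (False, {"idx": {"paragraph": 0, "question": 0}, "label": 0}),
        (False, {"idx": {"paragraph": 0, "question": 1}, "label": 1}),
    ]

    scores = {}
    for pred, doc in items:
        key = (doc["idx"]["paragraph"], doc["idx"]["question"])
        scores.setdefault(key, []).append((doc["label"] == 1) == pred)

    # Question 0: both answers correct -> 1; question 1: wrong -> 0.
    print(np.mean([int(all(v)) for v in scores.values()]))  # 0.5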

 def acc_all_stderr(items):
     ...

@@ -328,11 +236,6 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
     return max(scores_for_ground_truths)

-def weighted_mean(items):
-    a, b = zip(*items)
-    return sum(a) / sum(b)

 def is_non_str_iterable(obj):
     return isinstance(obj, Iterable) and not isinstance(obj, str)
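Note: is_non_str_iterable is the usual guard for telling a single string apart from a list of strings, presumably so a lone reference or prediction can be wrapped into a list before the corpus scorers run. For example:

    from collections.abc import Iterable


    def is_non_str_iterable(obj):
        return isinstance(obj, Iterable) and not isinstance(obj, str)


    print(is_non_str_iterable("a single reference"))    # False
    print(is_non_str_iterable(["ref one", "ref two"]))  # True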