gaoqiong / lm-evaluation-harness · Commits · e7cd7d68
"docs/source/vscode:/vscode.git/clone" did not exist on "b52f7756fbcf6669dbe92e97e11415c4084cf881"
Commit e7cd7d68 authored Dec 19, 2023 by lintangsutawika
sample metrics that have both sample-wise and set-wise operations

parent 08fcf1fe
Showing 1 changed file with 166 additions and 136 deletions
lm_eval/api/metrics.py (view file @ e7cd7d68)
...
@@ -13,6 +13,30 @@ import logging

eval_logger = logging.getLogger("lm-eval")


class BaseMetric:
    def __init__(
        self,
        aggregation=None,
    ) -> None:
        self.aggregation = aggregation

    def __call__(self, *items):
        sample_wise_score = self.sample_wise_compute(*items)
        if self.aggregation is not None:
            return self.aggregation(sample_wise_score)
        else:
            return self.set_wise_compute(sample_wise_score)

    def sample_wise_compute(self, *items):
        return items

    def set_wise_compute(self, *items):
        return items


# Register Aggregations First
@register_aggregation("mean")
def mean(arr):

...
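The new BaseMetric above dispatches each call to sample_wise_compute and then either applies the configured aggregation or falls back to set_wise_compute. A minimal usage sketch, assuming the class exactly as defined in this hunk; the MeanAccMetric subclass and the inline aggregation are illustrative and not part of this commit:

    # Hypothetical subclass to illustrate the dispatch in BaseMetric.__call__.
    class MeanAccMetric(BaseMetric):
        def sample_wise_compute(self, ll, is_greedy):
            # Per-sample score: 1 if the gold continuation was the greedy choice.
            return int(is_greedy)

    # With an aggregation, __call__ applies it to the sample-wise score.
    with_agg = MeanAccMetric(aggregation=lambda score: score)
    print(with_agg(-1.3, True))     # sample_wise_compute -> 1, aggregation(1) -> 1

    # Without an aggregation, __call__ falls back to set_wise_compute,
    # which by default just returns its arguments as a tuple.
    without_agg = MeanAccMetric()
    print(without_agg(-1.3, True))  # (1,)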
@@ -24,21 +48,28 @@ def median(arr):
    return arr[len(arr) // 2]


# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@register_aggregation("perplexity")
def perplexity(items):
    return math.exp(-mean(items))


@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
)
class PerplexityMetric(BaseMetric):
    def sample_wise_compute(self, ll, is_greedy):
        return ll


@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))

    def set_wise_compute(self, items):
        return math.exp(-mean(items))


@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)


@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
class LoglikelihoodAccMetric(BaseMetric):
    def __call__(self, ll, is_greedy):
        return int(is_greedy)


@register_aggregation("f1")

...
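For intuition on the perplexity path above: the sample-wise step keeps each sequence's log-likelihood, and the set-wise step (like the "perplexity" aggregation) takes exp of the negative mean. A small self-contained check with made-up log-likelihoods:

    import math

    lls = [-1.2, -0.7, -2.3]      # assumed per-sample log-likelihoods

    def mean(arr):
        return sum(arr) / len(arr)

    # Same computation as PerplexityMetric.set_wise_compute / the "perplexity" aggregation.
    print(math.exp(-mean(lls)))   # exp(1.4) ≈ 4.055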
@@ -109,87 +140,86 @@ def ter(items):
    return sacrebleu.corpus_ter(preds, refs).score


@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_fn(items):  # This is a passthrough function
    return items


# @register_metric(
#     metric="acc",
#     higher_is_better=True,
#     output_type=["loglikelihood", "multiple_choice"],
#     aggregation="mean",
# )
# def acc_fn(items):  # This is a passthrough function
#     return items


@register_metric(
    metric="acc_norm",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_norm_fn(items):  # This is a passthrough function
    return items


# @register_metric(
#     metric="acc_norm",
#     higher_is_better=True,
#     output_type=["loglikelihood", "multiple_choice"],
#     aggregation="mean",
# )
# def acc_norm_fn(items):  # This is a passthrough function
#     return items


@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def acc_mutual_info_fn(items):  # This is a passthrough function
    return items


# @register_metric(
#     metric="acc_mutual_info",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="mean",
# )
# def acc_mutual_info_fn(items):  # This is a passthrough function
#     return items


exact_match = evaluate.load("exact_match")


@register_metric(
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="mean",
)
def exact_match_fn(**kwargs):
    return exact_match.compute(**kwargs)


@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
    aggregation="perplexity",
)
def perplexity_fn(items):  # This is a passthrough function
    return items


# @register_metric(
#     metric="exact_match",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="mean",
# )
# def exact_match_fn(**kwargs):
#     return exact_match.compute(**kwargs)


@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def word_perplexity_fn(items):  # This is a passthrough function
    return items


class BytePerplexityMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _words

    def set_wise_compute(self, items):
        return math.exp(-weighted_mean(items))


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items):  # This is a passthrough function
    return items


class BytePerplexityMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _bytes

    def set_wise_compute(self, items):
        return math.exp(-weighted_mean(items))


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="bits_per_byte",
)
def bits_per_byte_fn(items):  # This is a passthrough function
    return items


class BitsPerByteMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _bytes

    def set_wise_compute(self, items):
        return -weighted_mean(items) / math.log(2)


def pop_stddev(arr):

...
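weighted_mean is used by the byte-level classes above but is not shown in this hunk; assuming the usual lm-eval convention (sum the log-likelihoods and divide by the summed unit counts), the two set-wise computations reduce to the sketch below. The (loglikelihood, byte_count) pairs are fabricated for illustration:

    import math

    def weighted_mean(items):
        # Assumed helper, not part of this hunk: items are (value, weight) pairs.
        a, b = zip(*items)
        return sum(a) / sum(b)

    items = [(-45.0, 120), (-30.0, 80)]          # (loglikelihood, byte_count)

    byte_ppl = math.exp(-weighted_mean(items))   # "weighted_perplexity": ≈ 1.455
    bpb = -weighted_mean(items) / math.log(2)    # "bits_per_byte": ≈ 0.541
    print(byte_ppl, bpb)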
@@ -206,79 +236,79 @@ def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))


@register_metric(
    metric="mcc",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="matthews_corrcoef",
)
def mcc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="f1",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="f1",
)
def f1_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bleu",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="bleu",
)
def bleu_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="chrf",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="chrf",
)
def chrf_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="ter",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="ter",
)
def ter_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_all",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc


# @register_metric(
#     metric="mcc",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="matthews_corrcoef",
# )
# def mcc_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="f1",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="f1",
# )
# def f1_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="bleu",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="bleu",
# )
# def bleu_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="chrf",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="chrf",
# )
# def chrf_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="ter",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="ter",
# )
# def ter_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="acc_all",
#     higher_is_better=True,
#     output_type="loglikelihood",
#     aggregation="mean",
# )
# def acc_all(items):
#     # Only count as correct if all answers are labeled correctly for each question
#     question_scoring_dict = {}
#     preds = list(zip(*items))[0]
#     docs = list(zip(*items))[1]
#     for doc, pred in zip(docs, preds):
#         paragraph_id = doc["idx"]["paragraph"]
#         question_id = doc["idx"]["question"]
#         if (paragraph_id, question_id) not in question_scoring_dict:
#             question_scoring_dict[(paragraph_id, question_id)] = []
#         gold_label = doc["label"] == 1
#         question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
#     acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
#     return acc


def acc_all_stderr(items):

...
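The acc_all metric above only credits a question when every one of its answer candidates is predicted correctly, grouping samples by (paragraph, question). A standalone check of that grouping logic, with fabricated (pred, doc) pairs in the shape acc_all expects:

    import numpy as np

    # Fabricated items: (prediction, doc) pairs.
    items = [
        (True,  {"idx": {"paragraph": 0, "question": 0}, "label": 1}),
        (False, {"idx": {"paragraph": 0, "question": 0}, "label": 0}),
        (True,  {"idx": {"paragraph": 0, "question": 1}, "label": 0}),
    ]

    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    scores = {}
    for doc, pred in zip(docs, preds):
        key = (doc["idx"]["paragraph"], doc["idx"]["question"])
        scores.setdefault(key, []).append((doc["label"] == 1) == pred)

    # Question (0, 0): both candidates correct -> 1; question (0, 1): wrong -> 0.
    print(np.mean([int(all(v)) for v in scores.values()]))  # 0.5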