gaoqiong / lm-evaluation-harness

Commit de46fb9a, authored Jan 02, 2024 by lintangsutawika
Commit message: reformat
Parent: dfb036b7
Showing 2 changed files, with 11 additions and 32 deletions:

    lm_eval/api/registry.py  (+8, -9)
    lm_eval/api/task.py      (+3, -23)
lm_eval/api/registry.py (view file @ de46fb9a)
-import os
-import logging
-import evaluate
 import collections
+import logging
 from functools import partial
+
+import evaluate
+
 from lm_eval.api.model import LM

 eval_logger = logging.getLogger("lm-eval")

 MODEL_REGISTRY = {}
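The import changes above are a conventional isort-style cleanup: standard-library imports (collections, logging, functools) come first, the third-party package (evaluate) next, and the first-party module (lm_eval) last, with blank lines between groups; the `import os` line is removed outright, presumably because it was unused.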
@@ -92,9 +93,9 @@ def register_metric(
 ):
     # TODO: do we want to enforce a certain interface to registered metrics?
     def decorate(fn):
-        if type(metric) == str:
+        if isinstance(metric, str):
             metric_list = [metric]
-        elif type(metric) == list:
+        elif isinstance(metric, list):
             metric_list = metric
         for _metric in metric_list:
@@ -107,9 +108,9 @@ def register_metric(
             METRIC_REGISTRY[_metric]["higher_is_better"] = higher_is_better

         if output_type is not None:
-            if type(output_type) == str:
+            if isinstance(output_type, str):
                 output_type_list = [output_type]
-            elif type(output_type) == list:
+            elif isinstance(output_type, list):
                 output_type_list = output_type
             for _output_type in output_type_list:
@@ -121,7 +122,6 @@ def register_metric(
-
 def get_metric(name):
     if name in METRIC_REGISTRY:
         return METRIC_REGISTRY[name]
     else:
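For orientation, a hypothetical usage sketch of the registry pair shown above. The metric name, function body, and argument values are invented; the keyword parameters (metric, higher_is_better, output_type) are inferred from what is visible in these hunks:

    from lm_eval.api.registry import register_metric, get_metric

    @register_metric(
        metric="my_exact_match",       # hypothetical name; may also be a list
        higher_is_better=True,
        output_type="generate_until",  # assumed value; may also be a list
    )
    def my_exact_match(items):
        # items assumed to be (gold, prediction) pairs
        return sum(g == p for g, p in items) / len(items)

    fn = get_metric("my_exact_match")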
@@ -133,7 +133,6 @@ def get_evaluate(name, **kwargs):
-
 class HFEvaluateAdaptor:
     def __init__(self, name, **kwargs):
         self.name = name
         metric_object = evaluate.load(name)
         self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)
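The truncated HFEvaluateAdaptor above pre-binds keyword arguments to a Hugging Face `evaluate` metric's compute method via functools.partial. A minimal sketch of the same pattern using the real `evaluate` API; the metric name and keyword argument are illustrative:

    from functools import partial

    import evaluate

    # Load a metric from the Hugging Face evaluate hub and pre-bind kwargs,
    # mirroring what the adaptor's __init__ does with **kwargs.
    metric_object = evaluate.load("exact_match")
    hf_evaluate_fn = partial(metric_object.compute, ignore_case=True)

    # The bound function is then called with predictions and references.
    print(hf_evaluate_fn(predictions=["Paris"], references=["paris"]))
    # -> {'exact_match': 1.0}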
lm_eval/api/task.py (view file @ de46fb9a)
@@ -18,31 +18,13 @@ from lm_eval.api.metrics import (
     bits_per_byte,
     mean,
     weighted_perplexity,
-<<<<<<< HEAD
-<<<<<<< HEAD
-=======
->>>>>>> cda25fef4e1df2f4bc2dab3ec6668ae9f5bf7296
-    bits_per_byte,
-)
-from lm_eval.api.registry import (
-    get_metric,
-    get_evaluate,
-    get_aggregation,
-    METRIC_REGISTRY,
-    DEFAULT_METRIC_REGISTRY,
-<<<<<<< HEAD
-=======
 )
 from lm_eval.api.registry import (
     AGGREGATION_REGISTRY,
     DEFAULT_METRIC_REGISTRY,
     METRIC_REGISTRY,
     get_aggregation,
     get_evaluate,
     get_metric,
     get_metric_aggregation,
     is_higher_better,
->>>>>>> 4d10ad56b1ffe569467eee2297e2317c99313118
-=======
->>>>>>> cda25fef4e1df2f4bc2dab3ec6668ae9f5bf7296
 )
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.prompts import get_prompt
@@ -603,7 +585,7 @@ class ConfigurableTask(Task):
                 metric_fn = metric_name.__call__
                 metric_name = metric_name.__name__
             else:
-                assert type(metric_name) == str
+                assert isinstance(metric_name, str)
                 if use_hf_evaluate:
                     metric_fn = get_evaluate(metric_name, **kwargs)
                 elif metric_name in METRIC_REGISTRY:
@@ -620,7 +602,6 @@ class ConfigurableTask(Task):
-
                     self._aggregation_list[metric_name] = metric_fn
                 else:
                     if "aggregation" in metric_config:
                         agg_name = metric_config["aggregation"]
                         if isinstance(agg_name, str):
                             self._aggregation_list[metric_name] = get_aggregation(
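The branch above resolves a string such as "mean" from a task's metric config into an aggregation callable. A hypothetical illustration, assuming "mean" is registered in AGGREGATION_REGISTRY as it is upstream:

    from lm_eval.api.registry import get_aggregation

    mean_fn = get_aggregation("mean")
    print(mean_fn([0.0, 1.0, 1.0]))  # 0.666..., the mean of the per-doc scores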
@@ -1028,7 +1009,6 @@ class ConfigurableTask(Task):
         )

-
     def process_results(self, doc, results):
         # Process results returns 1 of X things per doc/results
         # 1. A score
         # 2. Components to be processed later to obtain a score, such as gold and prediction
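The comment in that final hunk describes two possible return shapes. A hypothetical process_results illustrating both; the metric keys and the doc field are invented:

    def process_results(self, doc, results):
        pred = results[0]
        gold = doc["answer"]  # hypothetical doc field
        return {
            "acc": float(pred == gold),  # 1. an immediate per-doc score
            "bleu": (gold, pred),        # 2. components aggregated into a score later
        }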