Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
a22d8ffa
Commit
a22d8ffa
authored
Jun 07, 2023
by
lintangsutawika
Browse files
modified import orgin
parent
f2166089
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
8 deletions
+14
-8
lm_eval/api/task.py
lm_eval/api/task.py
+14
-8
No files found.
lm_eval/api/task.py
View file @
a22d8ffa
...
@@ -23,18 +23,20 @@ from lm_eval.api.filter import FilterEnsemble
...
@@ -23,18 +23,20 @@ from lm_eval.api.filter import FilterEnsemble
from
lm_eval.logger
import
eval_logger
from
lm_eval.logger
import
eval_logger
from
lm_eval.prompts
import
get_prompt
from
lm_eval.prompts
import
get_prompt
from
lm_eval.filters
import
build_filter_ensemble
from
lm_eval.filters
import
build_filter_ensemble
from
lm_eval.metrics
import
(
from
lm_eval.api.metrics
import
(
# get_metric,
# get_aggregation,
mean
,
weighted_perplexity
,
bits_per_byte
,
)
from
lm_eval.api.registry
import
(
METRIC_REGISTRY
,
METRIC_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
OUTPUT_TYPE_REGISTRY
,
OUTPUT_TYPE_REGISTRY
,
AGGREGATION_REGISTRY
,
AGGREGATION_REGISTRY
,
HIGHER_IS_BETTER_REGISTRY
,
HIGHER_IS_BETTER_REGISTRY
,
DEFAULT_AGGREGATION_REGISTRY
,
DEFAULT_AGGREGATION_REGISTRY
,
# get_metric,
# get_aggregation,
mean
,
weighted_perplexity
,
bits_per_byte
,
)
)
ALL_OUTPUT_TYPES
=
[
ALL_OUTPUT_TYPES
=
[
...
@@ -504,8 +506,9 @@ class ConfigurableTask(Task):
...
@@ -504,8 +506,9 @@ class ConfigurableTask(Task):
)
)
for
metric_name
in
_metric_list
:
for
metric_name
in
_metric_list
:
self
.
_metric_fn_list
[
metric_name
]
=
METRIC_REGISTRY
[
metric_name
]
self
.
_metric_fn_list
[
metric_name
]
=
METRIC_REGISTRY
[
metric_name
]
aggregation
=
DEFAULT_AGGREGATION_REGISTRY
[
metric_name
]
self
.
_aggregation_list
[
metric_name
]
=
DEFAULT_AGGREGATION_REGISTRY
[
self
.
_aggregation_list
[
metric_name
]
=
AGGREGATION_REGISTRY
[
aggregation
]
metric_name
]
self
.
_higher_is_better
[
metric_name
]
=
HIGHER_IS_BETTER_REGISTRY
[
self
.
_higher_is_better
[
metric_name
]
=
HIGHER_IS_BETTER_REGISTRY
[
metric_name
metric_name
]
]
...
@@ -754,6 +757,9 @@ class ConfigurableTask(Task):
...
@@ -754,6 +757,9 @@ class ConfigurableTask(Task):
def
process_results
(
self
,
doc
,
results
):
def
process_results
(
self
,
doc
,
results
):
# if callable(self._config.process_results):
# return self._config.process_results(doc, results)
result_dict
=
{}
result_dict
=
{}
use_metric
=
list
(
self
.
_metric_fn_list
.
keys
())
use_metric
=
list
(
self
.
_metric_fn_list
.
keys
())
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment