Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
f107ae29
Commit
f107ae29
authored
Apr 27, 2023
by
lintangsutawika
Browse files
ported changes here
parent
2a9da9fb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
17 deletions
+33
-17
lm_eval/api/task.py
lm_eval/api/task.py
+33
-17
No files found.
lm_eval/api/task.py
View file @
f107ae29
...
@@ -389,25 +389,30 @@ class ConfigurableTask(Task):
...
@@ -389,25 +389,30 @@ class ConfigurableTask(Task):
self
.
_metric_list
=
{}
self
.
_metric_list
=
{}
self
.
_aggregation_list
=
{}
self
.
_aggregation_list
=
{}
self
.
_higher_is_better
=
{}
self
.
_higher_is_better
=
{}
for
(
metric_name
,
aggregation
,
higher_is_better
)
in
self
.
_config
.
metric_list
:
self
.
_metric_kwargs
=
{}
for
metric_config
in
self
.
_config
.
metric_list
:
self
.
_aggregation_list
[
metric_name
]
=
get_aggregation
(
aggregation
)
metric_name
=
metric_config
[
'name'
]
self
.
_higher_is_better
[
metric_name
]
=
higher_is_better
aggregation
=
metric_config
[
'aggregation'
]
higher_is_better
=
metric_config
[
'higher_is_better'
]
kwargs
=
{
key
:
metric_config
[
key
]
for
key
in
metric_config
if
key
not
in
[
'name'
,
'aggregation'
,
'higher_is_better'
]}
self
.
_metric_list
[
metric_name
]
=
get_metric
(
metric_name
)
self
.
_aggregation_list
[
metric_name
]
=
AGGREGATION_REGISTRY
[
aggregation
]
self
.
_higher_is_better
[
metric_name
]
=
higher_is_better
if
metric_name
in
METRIC_REGISTRY
.
keys
():
self
.
_metric_list
[
metric_name
]
=
METRIC_REGISTRY
[
metric_name
]
else
:
try
:
metric_object
=
evaluate
.
load
(
metric_name
)
self
.
_metric_list
[
metric_name
]
=
metric_object
self
.
_metric_kwargs
[
metric_name
]
=
kwargs
# if metric_name in METRIC_REGISTRY.keys():
except
Exception
as
ex
:
# self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
raise
Warning
(
# else:
"{} not found in the evaluate library!"
.
format
(
metric_name
),
# try:
"Please check https://huggingface.co/evaluate-metric"
,
# metric_object = evaluate.load(metric_name)
)
# self._metric_list[metric_name] = metric_object
# except Exception as ex:
# raise Warning(
# "{} not found in the evaluate library!".format(metric_name),
# "Please check https://huggingface.co/evaluate-metric",
# )
self
.
download
(
data_dir
,
cache_dir
,
download_mode
)
self
.
download
(
data_dir
,
cache_dir
,
download_mode
)
self
.
_training_docs
=
None
self
.
_training_docs
=
None
...
@@ -468,8 +473,19 @@ class ConfigurableTask(Task):
...
@@ -468,8 +473,19 @@ class ConfigurableTask(Task):
def construct_requests(self, doc, ctx, **kwargs):
    """Build the evaluation ``Instance`` for a single document.

    The request arguments depend on ``self.output_type``:

    - ``"loglikelihood"``: score the target continuation given the
      context, so arguments are ``(ctx, target)``.
    - ``"loglikelihood_rolling"``: score the target text alone (rolling
      perplexity), so arguments are ``(target,)`` with no context.
    - ``"greedy_until"``: generate from the context until the stop
      sequence ``"\n\n"``, so arguments are ``(ctx, "\n\n")``.

    :param doc: the (formatted) source document for this request.
    :param ctx: the fully-constructed context string (few-shot prompt
        plus the current document's input).
    :param kwargs: forwarded verbatim to the ``Instance`` constructor.
    :returns: an ``Instance`` carrying ``request_type=self.output_type``,
        the document, and the arguments tuple built above.
    :raises ValueError: if ``self.output_type`` is not one of the three
        supported request types. (Previously an unrecognized type left
        ``arguments`` unbound and the return line raised an opaque
        ``UnboundLocalError``.)
    """
    if self.output_type == "loglikelihood":
        arguments = (ctx, self.doc_to_target(doc))
    elif self.output_type == "loglikelihood_rolling":
        # Rolling loglikelihood scores the target text on its own.
        arguments = (self.doc_to_target(doc),)
    elif self.output_type == "greedy_until":
        # Generate until the default paragraph-break stop sequence.
        arguments = (ctx, "\n\n")
    else:
        # Fail loudly instead of falling through with `arguments` unbound.
        raise ValueError(
            "Unsupported output_type: {!r}".format(self.output_type)
        )
    return Instance(
        request_type=self.output_type,
        doc=doc,
        arguments=arguments,
        **kwargs
    )
def
process_results
(
self
,
doc
,
results
):
def
process_results
(
self
,
doc
,
results
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment