Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
66c58194
Commit
66c58194
authored
May 15, 2023
by
lintangsutawika
Browse files
can process doc_to_text and doc_to_target as function
parent
51b795cf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
7 deletions
+28
-7
lm_eval/api/task.py
lm_eval/api/task.py
+28
-7
No files found.
lm_eval/api/task.py
View file @
66c58194
...
@@ -12,8 +12,10 @@ import functools
...
@@ -12,8 +12,10 @@ import functools
import
datasets
import
datasets
import
numpy
as
np
import
numpy
as
np
from
lm_eval
import
utils
from
typing
import
Union
from
collections.abc
import
Callable
from
lm_eval
import
utils
from
lm_eval.api
import
samplers
from
lm_eval.api
import
samplers
from
lm_eval.api.instance
import
Instance
from
lm_eval.api.instance
import
Instance
from
lm_eval.api.metrics
import
(
from
lm_eval.api.metrics
import
(
...
@@ -42,8 +44,8 @@ class TaskConfig(dict):
...
@@ -42,8 +44,8 @@ class TaskConfig(dict):
fewshot_split
:
str
=
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
fewshot_split
:
str
=
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases
:
str
=
""
template_aliases
:
str
=
""
doc_to_text
:
str
=
""
doc_to_text
:
Union
[
Callable
,
str
]
=
None
doc_to_target
:
str
=
""
doc_to_target
:
Union
[
Callable
,
str
]
=
None
num_fewshot
:
int
=
0
num_fewshot
:
int
=
0
...
@@ -66,8 +68,11 @@ class TaskConfig(dict):
...
@@ -66,8 +68,11 @@ class TaskConfig(dict):
# allow user-specified aliases so that users can
# allow user-specified aliases so that users can
# force prompt-compatibility for some prompt regardless of
# force prompt-compatibility for some prompt regardless of
# field names in prompt
# field names in prompt
self
.
doc_to_text
=
self
.
template_aliases
+
self
.
doc_to_text
if
type
(
self
.
doc_to_text
)
==
str
:
self
.
doc_to_target
=
self
.
template_aliases
+
self
.
doc_to_target
self
.
doc_to_text
=
self
.
template_aliases
+
self
.
doc_to_text
if
type
(
self
.
doc_to_target
)
==
str
:
self
.
doc_to_target
=
self
.
template_aliases
+
self
.
doc_to_target
# set "task_name" metadata field based on the "primary" name set
# set "task_name" metadata field based on the "primary" name set
if
self
.
names
:
if
self
.
names
:
...
@@ -439,9 +444,11 @@ class ConfigurableTask(Task):
...
@@ -439,9 +444,11 @@ class ConfigurableTask(Task):
self
.
OUTPUT_TYPE
=
self
.
_config
.
output_type
self
.
OUTPUT_TYPE
=
self
.
_config
.
output_type
if
self
.
_config
.
dataset_path
is
not
None
:
if
self
.
_config
.
dataset_path
is
not
None
:
print
(
self
.
_config
.
dataset_path
)
self
.
DATASET_PATH
=
self
.
_config
.
dataset_path
self
.
DATASET_PATH
=
self
.
_config
.
dataset_path
if
self
.
_config
.
dataset_name
is
not
None
:
if
self
.
_config
.
dataset_name
is
not
None
:
print
(
self
.
_config
.
dataset_name
)
self
.
DATASET_NAME
=
self
.
_config
.
dataset_name
self
.
DATASET_NAME
=
self
.
_config
.
dataset_name
if
self
.
_config
.
metric_list
is
not
None
:
if
self
.
_config
.
metric_list
is
not
None
:
...
@@ -546,10 +553,24 @@ class ConfigurableTask(Task):
...
@@ -546,10 +553,24 @@ class ConfigurableTask(Task):
doc_to_text
=
get_prompt
(
self
.
_config
.
use_prompt
)
doc_to_text
=
get_prompt
(
self
.
_config
.
use_prompt
)
else
:
else
:
doc_to_text
=
self
.
_config
.
doc_to_text
doc_to_text
=
self
.
_config
.
doc_to_text
return
utils
.
apply_template
(
doc_to_text
,
doc
)
print
(
doc_to_text
)
if
type
(
doc_to_text
)
==
str
:
return
utils
.
apply_template
(
doc_to_text
,
doc
)
elif
type
(
doc_to_text
)
==
Callable
:
return
doc_to_text
(
doc
)
else
:
raise
TypeError
def
doc_to_target
(
self
,
doc
):
def
doc_to_target
(
self
,
doc
):
return
utils
.
apply_template
(
self
.
_config
.
doc_to_target
,
doc
)
doc_to_target
=
self
.
_config
.
doc_to_target
if
type
(
doc_to_target
)
==
str
:
return
utils
.
apply_template
(
doc_to_target
,
doc
)
elif
type
(
doc_to_target
)
==
Callable
:
return
doc_to_target
(
doc
)
else
:
raise
TypeError
def
construct_requests
(
self
,
doc
,
ctx
,
**
kwargs
):
def
construct_requests
(
self
,
doc
,
ctx
,
**
kwargs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment