gaoqiong / lm-evaluation-harness / Commits

Commit 6be66284 (Unverified)
Authored Jun 15, 2023 by Lintang Sutawika; committed by GitHub, Jun 15, 2023
Parents: 400c0199, 0c53ff49

Merge branch 'big-refactor' into more-docs
Showing 4 changed files with 30 additions and 16 deletions (+30, -16)

  lm_eval/api/registry.py     +3   -1
  lm_eval/api/task.py         +19  -9
  lm_eval/tasks/__init__.py   +4   -0
  main.py                     +4   -6
lm_eval/api/registry.py

```diff
@@ -36,7 +36,7 @@ def get_model(model_name):
 TASK_REGISTRY = {}
 GROUP_REGISTRY = {}
-ALL_TASKS = []
+ALL_TASKS = set()
 func2task_index = {}
@@ -47,6 +47,7 @@ def register_task(name):
         ), f"task named '{name}' conflicts with existing registered task!"
         TASK_REGISTRY[name] = fn
+        ALL_TASKS.add(name)
         func2task_index[fn.__name__] = name
         return fn
@@ -60,6 +61,7 @@ def register_group(name):
             GROUP_REGISTRY[name].append(func_name)
         else:
             GROUP_REGISTRY[name] = [func_name]
+            ALL_TASKS.add(name)
         return fn

     return decorate
```
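The net effect of these hunks is that ALL_TASKS becomes a set populated by both decorators, so downstream code can enumerate every registered task and group name from one place. Below is a minimal, self-contained sketch of that behaviour, not the library's full registry module; the assert condition and the way register_group derives func_name are not visible in the diff, so they are assumptions here.

```python
TASK_REGISTRY = {}
GROUP_REGISTRY = {}
ALL_TASKS = set()  # previously a list; a set keeps task and group names unique
func2task_index = {}


def register_task(name):
    def decorate(fn):
        assert (
            name not in TASK_REGISTRY  # assumed condition; only the message appears in the diff
        ), f"task named '{name}' conflicts with existing registered task!"
        TASK_REGISTRY[name] = fn
        ALL_TASKS.add(name)  # new in this commit
        func2task_index[fn.__name__] = name
        return fn

    return decorate


def register_group(name):
    def decorate(fn):
        func_name = fn.__name__  # assumption: the hunk does not show how func_name is derived
        if name in GROUP_REGISTRY:
            GROUP_REGISTRY[name].append(func_name)
        else:
            GROUP_REGISTRY[name] = [func_name]
            ALL_TASKS.add(name)  # new in this commit: group names land in ALL_TASKS too
        return fn

    return decorate


# Hypothetical usage: one class registered under a task name and a group name.
@register_group("my_group")
@register_task("my_task")
class MyTask:
    pass


print(sorted(ALL_TASKS))  # ['my_group', 'my_task']
```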
lm_eval/api/task.py

```diff
@@ -435,7 +435,7 @@ class Task(abc.ABC):
         if num_fewshot == 0:
             labeled_examples = ""
         else:
-            labeled_examples = self.sampler.get_context(doc, self._config.num_fewshot)
+            labeled_examples = self.sampler.get_context(doc, num_fewshot)
         # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
         # if self.has_training_docs():
```
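The single changed line above makes the few-shot context honour the num_fewshot value passed in, rather than always reading self._config.num_fewshot. A tiny illustrative sketch with hypothetical stand-in names (not the Task class itself):

```python
class FakeSampler:
    """Stand-in for the real few-shot sampler: returns n placeholder examples."""

    def get_context(self, doc, n):
        return " ".join(["<example>"] * n)


def fewshot_context(doc, num_fewshot, sampler=FakeSampler()):
    if num_fewshot == 0:
        labeled_examples = ""
    else:
        # after the fix, the argument (not the config default) decides how many examples are drawn
        labeled_examples = sampler.get_context(doc, num_fewshot)
    return labeled_examples


print(fewshot_context({"query": "2+2="}, 3))  # '<example> <example> <example>'
```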
```diff
@@ -566,15 +566,24 @@ class ConfigurableTask(Task):
                 if "aggregation" in metric_config:
                     agg_name = metric_config["aggregation"]
-                    self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[agg_name]
+                    if type(agg_name) == str:
+                        self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[agg_name]
+                    elif callable(agg_name):
+                        self._aggregation_list[metric_name] = metric_config["aggregation"]
                 else:
+                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
+                    metric_agg = DEFAULT_AGGREGATION_REGISTRY[metric_name]
                     eval_logger.warning(
-                        f"metric {metric_name} is defined, but aggregation is not"
-                        f"using default aggregation for {metric_name}"
+                        f"metric {metric_name} is defined, but aggregation is not. "
+                        f"using default "
+                        f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
                     )
-                    self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
-                        metric_name
-                    ]
+                    self._aggregation_list[metric_name] = metric_agg

                 if "higher_is_better" in metric_config:
                     self._higher_is_better[metric_name] = metric_config[
@@ -582,8 +591,9 @@ class ConfigurableTask(Task):
                     ]
                 else:
                     eval_logger.warning(
-                        f"metric {metric_name} is defined, but higher_is_better is not"
-                        f"using default higher_is_better for {metric_name}"
+                        f"metric {metric_name} is defined, but higher_is_better is not. "
+                        f"using default "
+                        f"higher_is_better={HIGHER_IS_BETTER_REGISTRY[metric_name]}"
                     )
                     self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
                         metric_name
```
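The first hunk lets a metric's "aggregation" entry be either a registered name or a callable, and makes the fallback warning name the default aggregation it actually picks (the higher_is_better fallback in the second hunk is reworded the same way). A standalone sketch of that resolution logic, assuming toy stand-in registries; the real ones live in lm_eval.api:

```python
import logging
import statistics

eval_logger = logging.getLogger("lm-eval")

# Toy stand-ins for the registries referenced in the hunk.
AGGREGATION_REGISTRY = {"mean": statistics.mean, "median": statistics.median}
DEFAULT_AGGREGATION_REGISTRY = {"acc": statistics.mean}


def resolve_aggregation(metric_name, metric_config):
    """Return the aggregation callable for one metric_list entry."""
    if "aggregation" in metric_config:
        agg_name = metric_config["aggregation"]
        if isinstance(agg_name, str):
            # a registered name: look the callable up
            return AGGREGATION_REGISTRY[agg_name]
        if callable(agg_name):
            # a function supplied directly in the config: use it as-is
            return agg_name
    # no usable aggregation given: fall back to the metric's default and say which one
    inv_agg_registry = {v: k for k, v in AGGREGATION_REGISTRY.items()}
    metric_agg = DEFAULT_AGGREGATION_REGISTRY[metric_name]
    eval_logger.warning(
        f"metric {metric_name} is defined, but aggregation is not. "
        f"using default aggregation={inv_agg_registry[metric_agg]}"
    )
    return metric_agg


print(resolve_aggregation("acc", {"aggregation": "median"}))  # looked up by name
print(resolve_aggregation("acc", {"aggregation": max}))       # callable passed through
print(resolve_aggregation("acc", {}))                         # default (mean), with a warning
```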
lm_eval/tasks/__init__.py

```diff
@@ -13,6 +13,7 @@ from lm_eval.api.registry import (
     register_group,
     TASK_REGISTRY,
     GROUP_REGISTRY,
+    ALL_TASKS,
 )
@@ -39,6 +40,9 @@ def include_task_folder(task_dir):
                 )
                 if "task" in config:
+                    # task_name = "{}:{}".format(
+                    #     get_task_name_from_config(config), config["task"]
+                    # )
                     task_name = "{}".format(config["task"])
                     register_task(task_name)(SubClass)
```
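In effect, each parsed YAML config whose "task" key is set is registered under that plain name, with the namespaced "{group}:{task}" form left as a comment for now. A hedged sketch of that step, where register_task and SubClass are simplified stand-ins for the real objects built inside include_task_folder:

```python
registered = {}


def register_task(task_name):
    """Simplified stand-in for lm_eval.api.registry.register_task."""

    def decorate(cls):
        registered[task_name] = cls
        return cls

    return decorate


class SubClass:
    """Placeholder for the ConfigurableTask subclass built from the YAML config."""


config = {"task": "arc_easy"}  # stand-in for a parsed task YAML

if "task" in config:
    # task_name = "{}:{}".format(get_task_name_from_config(config), config["task"])
    task_name = "{}".format(config["task"])
    register_task(task_name)(SubClass)

print(registered)  # {'arc_easy': <class '__main__.SubClass'>}
```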
main.py

```diff
@@ -4,11 +4,10 @@ import fnmatch
 import argparse

 from lm_eval import evaluator, utils
-from lm_eval.api.registry import GROUP_REGISTRY, TASK_REGISTRY
+from lm_eval.api.registry import ALL_TASKS
 from lm_eval.logger import eval_logger

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

-ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))

 class MultiChoice:
@@ -21,9 +20,8 @@ class MultiChoice:
             if len(fnmatch.filter(self.choices, value)) == 0:
                 eval_logger.warning("{} is not in task list.".format(value))
                 eval_logger.info(f"Available tasks to choose:")
-                # for choice in self.choices:
-                #     eval_logger.info(f"  {choice}")
-                eval_logger.info(ALL_TASKS)
+                for choice in self.choices:
+                    eval_logger.info(f"  - {choice}")
         return True

     def __iter__(self):
@@ -35,7 +33,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", required=True)
     parser.add_argument("--model_args", default="")
-    parser.add_argument("--tasks", default=None, choices=MultiChoice(ALL_TASKS))
+    parser.add_argument("--tasks", default=None, choices=MultiChoice(sorted(ALL_TASKS)))
     parser.add_argument("--config", default=None)
     parser.add_argument("--provide_description", action="store_true")
     parser.add_argument("--num_fewshot", type=int, default=0)
```
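Taken together, the main.py changes drop the locally computed ALL_TASKS in favour of the registry's set, sort it when wiring it into --tasks, and list the available choices one per line when validation fails. A self-contained sketch of how that ends up working, with hypothetical task names and print standing in for eval_logger:

```python
import argparse
import fnmatch

# Stand-in for the registry-backed set imported in the new code.
ALL_TASKS = {"arc_easy", "arc_challenge", "hellaswag"}


class MultiChoice:
    def __init__(self, choices):
        self.choices = choices

    # simple wildcard support (linux filename patterns), as in the diff context
    def __contains__(self, values):
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                print(f"{value} is not in task list. Available tasks:")
                for choice in self.choices:
                    print(f"  - {choice}")
        return True

    def __iter__(self):
        for choice in self.choices:
            yield choice


parser = argparse.ArgumentParser()
parser.add_argument("--tasks", default=None, choices=MultiChoice(sorted(ALL_TASKS)))
args = parser.parse_args(["--tasks", "arc_*"])
print(args.tasks)  # 'arc_*': patterns pass validation and are expanded later by the caller
```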