Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
8add5ed6
Commit
8add5ed6
authored
Jul 24, 2023
by
lintangsutawika
Browse files
Use num_fewshot set in yaml and show warning if it's being overwritten by argparse
parent
2820042d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
9 deletions
+32
-9
lm_eval/api/task.py
lm_eval/api/task.py
+3
-0
lm_eval/evaluator.py
lm_eval/evaluator.py
+13
-5
lm_eval/utils.py
lm_eval/utils.py
+15
-3
main.py
main.py
+1
-1
No files found.
lm_eval/api/task.py
View file @
8add5ed6
...
@@ -130,6 +130,9 @@ class TaskConfig(dict):
...
@@ -130,6 +130,9 @@ class TaskConfig(dict):
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
return
getattr
(
self
,
item
)
return
getattr
(
self
,
item
)
def
__setitem__
(
self
,
item
,
value
):
return
setattr
(
self
,
item
,
value
)
def
to_dict
(
self
):
def
to_dict
(
self
):
"""dumps the current config as a dictionary object, as a printable format.
"""dumps the current config as a dictionary object, as a printable format.
null fields will not be printed.
null fields will not be printed.
...
...
lm_eval/evaluator.py
View file @
8add5ed6
...
@@ -35,7 +35,7 @@ def simple_evaluate(
...
@@ -35,7 +35,7 @@ def simple_evaluate(
model
,
model
,
model_args
=
None
,
model_args
=
None
,
tasks
=
[],
tasks
=
[],
num_fewshot
=
0
,
num_fewshot
=
None
,
batch_size
=
None
,
batch_size
=
None
,
max_batch_size
=
None
,
max_batch_size
=
None
,
device
=
None
,
device
=
None
,
...
@@ -112,7 +112,17 @@ def simple_evaluate(
...
@@ -112,7 +112,17 @@ def simple_evaluate(
+
"_rank"
+
str
(
lm
.
rank
)
+
".db"
,
+
"_rank"
+
str
(
lm
.
rank
)
+
".db"
,
)
)
task_dict
=
lm_eval
.
tasks
.
get_task_dict
(
tasks
,
num_fewshot
=
num_fewshot
)
task_dict
=
lm_eval
.
tasks
.
get_task_dict
(
tasks
)
for
task_name
in
task_dict
.
keys
():
config
=
task_dict
[
task_name
].
_config
if
num_fewshot
is
not
None
:
if
config
[
"num_fewshot"
]
>
0
:
default_num_fewshot
=
config
[
"num_fewshot"
]
eval_logger
.
warning
(
f
"Overwriting default num_fewshot of
{
task_name
}
from
{
default_num_fewshot
}
to
{
num_fewshot
}
"
)
task_dict
[
task_name
].
_config
.
__setitem__
(
"num_fewshot"
,
num_fewshot
)
if
check_integrity
:
if
check_integrity
:
run_task_tests
(
task_list
=
tasks
)
run_task_tests
(
task_list
=
tasks
)
...
@@ -134,7 +144,6 @@ def simple_evaluate(
...
@@ -134,7 +144,6 @@ def simple_evaluate(
if
isinstance
(
model
,
str
)
if
isinstance
(
model
,
str
)
else
model
.
model
.
config
.
_name_or_path
,
else
model
.
model
.
config
.
_name_or_path
,
"model_args"
:
model_args
,
"model_args"
:
model_args
,
"num_fewshot"
:
num_fewshot
,
"batch_size"
:
batch_size
,
"batch_size"
:
batch_size
,
"batch_sizes"
:
list
(
lm
.
batch_sizes
.
values
())
"batch_sizes"
:
list
(
lm
.
batch_sizes
.
values
())
if
hasattr
(
lm
,
"batch_sizes"
)
if
hasattr
(
lm
,
"batch_sizes"
)
...
@@ -169,8 +178,6 @@ def evaluate(
...
@@ -169,8 +178,6 @@ def evaluate(
Language Model
Language Model
:param task_dict: dict[str, Task]
:param task_dict: dict[str, Task]
Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
:param num_fewshot: int
Number of examples in few-shot context
:param limit: int, optional
:param limit: int, optional
Limit the number of examples per task (only use this for testing)
Limit the number of examples per task (only use this for testing)
:param bootstrap_iters:
:param bootstrap_iters:
...
@@ -359,6 +366,7 @@ def evaluate(
...
@@ -359,6 +366,7 @@ def evaluate(
for
(
task_name
,
key
,
metric
),
items
in
vals
.
items
():
for
(
task_name
,
key
,
metric
),
items
in
vals
.
items
():
task
=
task_dict
[
task_name
]
task
=
task_dict
[
task_name
]
results
[
task_name
][
metric
+
","
+
key
]
=
task
.
aggregation
()[
metric
](
items
)
results
[
task_name
][
metric
+
","
+
key
]
=
task
.
aggregation
()[
metric
](
items
)
# results[task_name]['num_fewshot'] = configs[task_name]
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
# so we run them less iterations. still looking for a cleaner way to do this
...
...
lm_eval/utils.py
View file @
8add5ed6
...
@@ -265,9 +265,19 @@ def make_table(result_dict):
...
@@ -265,9 +265,19 @@ def make_table(result_dict):
md_writer
=
MarkdownTableWriter
()
md_writer
=
MarkdownTableWriter
()
latex_writer
=
LatexTableWriter
()
latex_writer
=
LatexTableWriter
()
md_writer
.
headers
=
[
"Task"
,
"Version"
,
"Filter"
,
"Metric"
,
"Value"
,
""
,
"Stderr"
]
md_writer
.
headers
=
[
"Task"
,
"Fewshot"
,
"Version"
,
"Filter"
,
"Metric"
,
"Value"
,
""
,
"Stderr"
,
]
latex_writer
.
headers
=
[
latex_writer
.
headers
=
[
"Task"
,
"Task"
,
"Fewshot"
,
"Version"
,
"Version"
,
"Filter"
,
"Filter"
,
"Metric"
,
"Metric"
,
...
@@ -280,6 +290,7 @@ def make_table(result_dict):
...
@@ -280,6 +290,7 @@ def make_table(result_dict):
for
k
,
dic
in
result_dict
[
"results"
].
items
():
for
k
,
dic
in
result_dict
[
"results"
].
items
():
version
=
result_dict
[
"versions"
][
k
]
version
=
result_dict
[
"versions"
][
k
]
n
=
str
(
result_dict
[
"configs"
][
k
][
"num_fewshot"
])
for
(
mf
),
v
in
dic
.
items
():
for
(
mf
),
v
in
dic
.
items
():
m
,
_
,
f
=
mf
.
partition
(
","
)
m
,
_
,
f
=
mf
.
partition
(
","
)
if
m
.
endswith
(
"_stderr"
):
if
m
.
endswith
(
"_stderr"
):
...
@@ -287,10 +298,11 @@ def make_table(result_dict):
...
@@ -287,10 +298,11 @@ def make_table(result_dict):
if
m
+
"_stderr"
+
","
+
f
in
dic
:
if
m
+
"_stderr"
+
","
+
f
in
dic
:
se
=
dic
[
m
+
"_stderr"
+
","
+
f
]
se
=
dic
[
m
+
"_stderr"
+
","
+
f
]
values
.
append
([
k
,
version
,
f
,
m
,
"%.4f"
%
v
,
"±"
,
"%.4f"
%
se
])
values
.
append
([
k
,
n
,
version
,
f
,
m
,
"%.4f"
%
v
,
"±"
,
"%.4f"
%
se
])
else
:
else
:
values
.
append
([
k
,
version
,
f
,
m
,
"%.4f"
%
v
,
""
,
""
])
values
.
append
([
k
,
n
,
version
,
f
,
m
,
"%.4f"
%
v
,
""
,
""
])
k
=
""
k
=
""
n
=
""
version
=
""
version
=
""
md_writer
.
value_matrix
=
values
md_writer
.
value_matrix
=
values
latex_writer
.
value_matrix
=
values
latex_writer
.
value_matrix
=
values
...
...
main.py
View file @
8add5ed6
...
@@ -28,7 +28,7 @@ def parse_args():
...
@@ -28,7 +28,7 @@ def parse_args():
parser
.
add_argument
(
parser
.
add_argument
(
"--num_fewshot"
,
"--num_fewshot"
,
type
=
int
,
type
=
int
,
default
=
0
,
default
=
None
,
help
=
"Number of examples in few-shot context"
,
help
=
"Number of examples in few-shot context"
,
)
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
1
)
# TODO: only integers
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
1
)
# TODO: only integers
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment