gaoqiong / lm-evaluation-harness · Commits · 025547c9

Unverified commit 025547c9, authored Apr 27, 2022 by Stella Biderman, committed by GitHub on Apr 27, 2022.

Merge pull request #6 from cjlovering/master

Update with new PR

Parents: 54999199, e5bc4354

Showing 4 changed files with 219 additions and 62 deletions (+219, -62):
- lm_eval/base.py (+143, -15)
- lm_eval/evaluator.py (+50, -17)
- lm_eval/tasks/coqa.py (+6, -13)
- main.py (+20, -17)
lm_eval/base.py

```diff
@@ -654,11 +654,21 @@ class PromptSourceTask(Task):
     *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
     """
 
-    CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
+    CONFIGURED_RANKED_CHOICE_PS_METRICS = set(["Accuracy"])
+    CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE"])
+    SPLIT = None
 
-    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
+    def __init__(
+        self,
+        data_dir=None,
+        cache_dir=None,
+        download_mode=None,
+        prompt=None,
+        save_examples=True,
+    ):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt
+        self.save_examples = save_examples
 
     def stopping_criteria(self) -> Optional[str]:
         """Denote where the generation should end.

@@ -752,12 +762,11 @@ class PromptSourceTask(Task):
             for metric in self.prompt.metadata.metrics:
                 assert (
-                    metric in self.CONFIGURED_PS_METRICS
+                    metric in self.CONFIGURED_RANKED_CHOICE_PS_METRICS
                 ), "Unexpected metric. Add it, or use a task-specific solution."
                 if metric == "Accuracy":
                     out["acc"] = pred == target
                 # TODO: Add metrics here.
-            return out
         else:
             # If not, then this is a generation prompt.
             # NOTE: In the future, target will be a list of strings.

@@ -765,11 +774,11 @@ class PromptSourceTask(Task):
             out = {}
             for metric in self.prompt.metadata.metrics:
                 assert (
-                    metric in self.CONFIGURED_PS_METRICS
+                    metric in self.CONFIGURED_GENERATION_PS_METRICS
                 ), "Unexpected metric. Add it, or use a task-specific solution."
                 if metric == "BLEU":
                     out["bleu"] = (target, pred)
-                if metric == "ROUGE":
+                elif metric == "ROUGE":
                     # TODO: This computes all rouge sub-metrics. Find a generic
                     # way to handle user specified rouge sub-metrics to avoid extra
                     # compute.

@@ -778,15 +787,21 @@ class PromptSourceTask(Task):
                     rouge_scores = utils.flatten(rouge_scores)
                     # Merge all the rouge-type scores into the `out` dict.
                     out = {**out, **rouge_scores}
-        print(out)
-        return out
+
+        # TODO: Wrap process results s.t. override impl do not
+        # override the save examples.
+        if self.save_examples:
+            example = {
+                "pred": pred,
+                "target": target,
+                "answer_choices_list": answer_choices_list,
+            }
+            return out, example
+        return out
 
     def higher_is_better(self):
         out = {}
         for metric in self.prompt.metadata.metrics:
-            assert (
-                metric in self.CONFIGURED_PS_METRICS
-            ), "Unexpected metric. Add it, or use a task-specific solution."
             if metric == "Accuracy":
                 out["acc"] = True
             if metric == "BLEU":

@@ -813,9 +828,6 @@ class PromptSourceTask(Task):
     def aggregation(self):
         out = {}
         for metric in self.prompt.metadata.metrics:
-            assert (
-                metric in self.CONFIGURED_PS_METRICS
-            ), "Unexpected metric. Add it, or use a task-specific solution."
             if metric == "Accuracy":
                 out["acc"] = mean
             if metric == "BLEU":

@@ -839,6 +851,122 @@ class PromptSourceTask(Task):
                 out["rougeLsum_fmeasure"] = mean
         return out
 
+    def fewshot_examples(self, k, rnd):
+        if self._training_docs is None:
+            self._training_docs = list(self.training_docs())
+        return self._get_fewshot_examples(self._training_docs, k, rnd)
+
+    def _get_fewshot_examples(self, docs, k, rnd):
+        fewshot_idx = rnd.sample(list(np.arange(len(docs))), k)
+        return [docs[idx] for idx in fewshot_idx], [int(idx) for idx in fewshot_idx]
+
+    @utils.positional_deprecated
+    def fewshot_context(
+        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
+    ):
+        """Returns a fewshot context string that is made up of a prepended description
+        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
+
+        :param doc: str
+            The document as returned from training_docs, validation_docs, or test_docs.
+        :param num_fewshot: int
+            The number of fewshot examples to provide in the returned context string.
+        :param provide_description: bool
+            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
+        :param rnd: random.Random
+            The pseudo-random number generator used to randomly sample examples.
+            WARNING: This is currently a required arg although it's optionalized with a default `None`.
+        :param description: str
+            The task's description that will be prepended to the fewshot examples.
+        :returns: str
+            The fewshot context.
+        """
+        assert (
+            rnd is not None
+        ), "A `random.Random` generator argument must be provided to `rnd`"
+        assert not provide_description, (
+            "The `provide_description` arg will be removed in future versions. To prepend "
+            "a custom description to the context, supply the corresponding string via the "
+            "`description` arg."
+        )
+        if provide_description is not None:
+            # nudge people to not specify it at all
+            print(
+                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
+            )
+
+        description = description + "\n\n" if description else ""
+
+        if num_fewshot == 0:
+            labeled_examples = ""
+            fewshotex, fewshotidx, fewshotsource = [], [], None
+        else:
+            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
+            if self.has_training_docs():
+                fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
+                fewshotsource = "train"
+            else:
+                if self._fewshot_docs is None:
+                    self._fewshot_docs = list(
+                        self.validation_docs()
+                        if self.has_validation_docs()
+                        else self.test_docs()
+                    )
+                if self.has_validation_docs():
+                    fewshotsource = "val"
+                elif self.test_docs():
+                    fewshotsource = "test"
+
+                fewshotex, fewshotidx = self._get_fewshot_examples(
+                    self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
+                )
+
+                fewshotex, fewshotidx = [
+                    (shot, idx)
+                    for shot, idx in zip(fewshotex, fewshotidx)
+                    if shot != doc
+                ]
+                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
+                fewshotex, fewshotidx = (
+                    fewshotex[:num_fewshot],
+                    fewshotidx[:num_fewshot],
+                )
+
+            labeled_examples = (
+                "\n\n".join(
+                    [
+                        self.doc_to_text(doc) + self.doc_to_target(doc)
+                        for doc in fewshotex
+                    ]
+                )
+                + "\n\n"
+            )
+
+        example = self.doc_to_text(doc)
+        ctx = description + labeled_examples + example
+        return (
+            ctx,
+            {
+                "fewshot_idx": fewshotidx,
+                "fewshot_source": fewshotsource,
+                "fewshot_num": num_fewshot,
+                "ctx": ctx,
+            },
+        )
+
+    def get_logging_info(self):
+        return {
+            "fixed_answer_choice_list": self.prompt.get_fixed_answer_choices_list(),
+            "dataset_path": self.DATASET_PATH,
+            "dataset_name": self.DATASET_NAME,
+            "subset": self.SPLIT,
+            "prompt_name": self.prompt.get_name(),
+            "prompt_id": self.prompt.get_id(),
+            "prompt_jinja": self.prompt.jinja,
+            "prompt_original_task": self.prompt.metadata.original_task,
+            # Placeholder for comment in post-processing.
+            "comment": "",
+        }
+
 
 class MultipleChoiceTask(Task):
     def doc_to_target(self, doc):
```
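For orientation, here is a minimal, self-contained sketch (not taken from the commit) of the `(metrics, example)` contract that the new `save_examples` flag introduces for `process_results`. The class below is a toy stand-in that only mimics the behaviour of `PromptSourceTask`; it is not the real implementation.

```python
# Hypothetical stand-in for PromptSourceTask, illustrating the new contract only.
class ToyRankedChoiceTask:
    def __init__(self, save_examples=True):
        self.save_examples = save_examples

    def process_results(self, target, pred, answer_choices_list):
        out = {"acc": pred == target}  # ranked-choice prompts report accuracy
        if self.save_examples:
            # Same keys as the example dict built in base.py above.
            example = {
                "pred": pred,
                "target": target,
                "answer_choices_list": answer_choices_list,
            }
            return out, example  # two values when examples are being saved
        return out  # metrics only, as before


task = ToyRankedChoiceTask(save_examples=True)
metrics, example = task.process_results("yes", "yes", ["yes", "no"])
print(metrics)  # {'acc': True}
print(example)  # {'pred': 'yes', 'target': 'yes', 'answer_choices_list': ['yes', 'no']}
```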
lm_eval/evaluator.py

```diff
@@ -173,10 +173,6 @@ def evaluate(
     # get lists of each type of request
     for task_prompt_name, task in task_dict_items:
-        # if task.is_generation_task():
-        #     print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
-        #     continue
-
         versions[task_prompt_name] = task.VERSION
         # default to test doc, fall back to val doc if validation unavailable
         # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point

@@ -188,7 +184,7 @@ def evaluate(
             raise RuntimeError("Task has neither test_docs nor validation_docs")
 
         # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
-        task_docs = list(task_doc_func())
+        task_docs = list(enumerate(list(task_doc_func())))
         rnd = random.Random()
         rnd.seed(42)
         rnd.shuffle(task_docs)

@@ -199,14 +195,17 @@ def evaluate(
             else ""
         )
 
-        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
+        for doc_id, (original_doc_id, doc) in enumerate(
+            itertools.islice(task_docs, 0, limit)
+        ):
             if task.invalid_doc_for_prompt(doc):
                 continue
 
             docs[(task_prompt_name, doc_id)] = doc
-            ctx = task.fewshot_context(
+            ctx, fewshotex_logging_info = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
+            fewshotex_logging_info["doc_id"] = original_doc_id
             reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):
                 reqs = [reqs]

@@ -215,7 +214,7 @@ def evaluate(
                 # i: index in requests for a single task instance
                 # doc_id: unique id that we can get back to a doc using `docs`
                 requests_origin[req.request_type].append(
-                    (i, task_prompt_name, doc, doc_id)
+                    (i, task_prompt_name, doc, doc_id, fewshotex_logging_info)
                 )
 
     # all responses for each (task, doc)

@@ -234,33 +233,57 @@ def evaluate(
             x if req.index is None else x[req.index]
             for x, req in zip(resps, reqs)
         ]
 
-        for resp, (i, task_prompt_name, doc, doc_id) in zip(
+        for resp, (i, task_prompt_name, doc, doc_id, fewshotex_logging_info) in zip(
             resps, requests_origin[reqtype]
         ):
-            process_res_queue[(task_prompt_name, doc_id)].append((i, resp))
+            process_res_queue[(task_prompt_name, doc_id)].append(
+                (i, resp, fewshotex_logging_info)
+            )
 
     vals = collections.defaultdict(list)
 
     # unpack results and sort back in order and return control to Task
-    for (task_prompt_name, doc_id), requests in process_res_queue.items():
-        requests.sort(key=lambda x: x[0])
-        requests = [x[1] for x in requests]
+    examples = []
+    for (task_prompt_name, doc_id), per_doc_requests in process_res_queue.items():
+        per_doc_requests.sort(key=lambda x: x[0])
+        per_doc_results = [x[1] for x in per_doc_requests]
+        fewshot_logging_info = [x[2] for x in per_doc_requests][0]
 
         task = task_dict[task_prompt_name]
         doc = docs[(task_prompt_name, doc_id)]
 
-        metrics = task.process_results(doc, requests)
+        output = task.process_results(doc, per_doc_results)
+
+        if task.save_examples:
+            metrics, example = output
+            example.update(fewshot_logging_info)
+            example.update(task.get_logging_info())
+            examples.append(example)
+        else:
+            metrics = output
+            example = fewshot_logging_info
+            example.update(task.get_logging_info())
+            examples.append(example)
+
         for metric, value in metrics.items():
             vals[(task_prompt_name, metric)].append(value)
 
     # aggregate results
+    metric_results = []
     for (task_prompt_name, metric), items in vals.items():
         task_name, prompt_name = task_prompt_name.split("+")
         results[task_prompt_name]["task_name"] = task_name
         results[task_prompt_name]["prompt_name"] = prompt_name
         task = task_dict[task_prompt_name]
         results[task_prompt_name][metric] = task.aggregation()[metric](items)
+
+        _metric_results = {
+            "task_name": task_name,
+            "prompt_name": prompt_name,
+            metric: task.aggregation()[metric](items),
+            **task.get_logging_info(),
+        }
+
         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
         stderr = lm_eval.metrics.stderr_for_metric(

@@ -271,8 +294,18 @@ def evaluate(
         )
         if stderr is not None:
             results[task_prompt_name][metric + "_stderr"] = stderr(items)
+            _metric_results[metric + "_stderr"] = stderr(items)
+
+        metric_results.append(_metric_results)
 
-    return {"results": dict(results), "versions": dict(versions)}
+    return {
+        # List of results that tracks the averages per model and prompt.
+        "results": metric_results,
+        "versions": dict(versions),
+        # List of all prompt x doc examples with additional information in it.
+        "examples": examples,
+        # Original results used for generating the table when running this file.
+        "table_results": dict(results),
+    }
 
 
 def make_table(result_dict):

@@ -293,7 +326,7 @@ def make_table(result_dict):
     ]
 
     values = []
-    for k, dic in result_dict["results"].items():
+    for k, dic in result_dict["table_results"].items():
         version = result_dict["versions"][k]
         for m, v in dic.items():
             if m.endswith("_stderr"):
```
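A rough sketch of how a caller might consume the richer dictionary that `evaluate()` now returns. The keys come from the hunks above; the sample values are hand-written for illustration and are not the output of an actual run (the prompt name `"first_qa_turn"` is made up).

```python
# Hand-written sample in the shape of the new return value of evaluate().
result_dict = {
    "results": [  # per (task, prompt) metric averages, one dict per metric
        {"task_name": "coqa", "prompt_name": "first_qa_turn", "f1": 0.41},
    ],
    "versions": {"coqa+first_qa_turn": 1},
    "examples": [  # one record per prompt x doc, with fewshot logging info merged in
        {"doc_id": 7, "fewshot_num": 0, "pred": "blue", "target": "blue"},
    ],
    "table_results": {  # original nested dict, still consumed by make_table()
        "coqa+first_qa_turn": {"task_name": "coqa", "prompt_name": "first_qa_turn", "f1": 0.41},
    },
}

# Aggregates go to the table / JSON dump; examples can be inspected separately.
for row in result_dict["results"]:
    print(row["task_name"], row["prompt_name"], row.get("f1"))

for ex in result_dict["examples"]:
    print("doc", ex["doc_id"], "pred:", ex["pred"], "target:", ex["target"])
```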
lm_eval/tasks/coqa.py

```diff
@@ -118,25 +118,18 @@ class CoQA(PromptSourceTask):
         """
         target = self.doc_to_target(doc).strip()
         pred = results[0].strip().split("\n")[0]
 
-        print("*" * 80)
-        print(f"DOC: {doc}")
-        # print(f"PS: {self.prompt.apply(doc)}")
-        print(f"TEXT: {self.doc_to_text(doc)}")
-        print(f"TARGET: {target} END TARGET")
-        print(f"PRED: {pred} END PRED")
-        print("*" * 80)
-        # turn_id = len(doc["questions"]["input_text"])
-        # gold_list = self.get_answers(doc, turn_id)
-        # TODO: Add HF metrics mapped from promptsource metadata.
         scores = self.compute_scores([target], pred)
 
-        return {
+        out = {
             "f1": scores["f1"],
             "em": scores["em"],
         }
+
+        if self.save_examples:
+            example = {"target": target, "pred": pred}
+            return out, example
+        return out
 
     def higher_is_better(self):
         return {
             "f1": True,
```
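Because `process_results` now returns either one value or two depending on `save_examples`, a caller that cannot rely on the flag could normalize the output with a small helper. The helper below is hypothetical and is not part of the commit; the evaluator above instead branches on `task.save_examples` directly.

```python
# Hypothetical helper (not in the commit): normalize both return shapes.
def unpack_process_results(output, save_examples):
    if save_examples:
        metrics, example = output  # new contract: (metrics dict, example dict)
        return metrics, example
    return output, None  # old contract: metrics dict only


out = {"f1": 0.5, "em": 0.0}
example = {"target": "a dog", "pred": "a cat"}

print(unpack_process_results((out, example), save_examples=True))
print(unpack_process_results(out, save_examples=False))
```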
main.py

```diff
@@ -9,27 +9,29 @@ logging.getLogger("openai").setLevel(logging.WARNING)
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', required=True)
-    parser.add_argument('--model_args', default="")
-    parser.add_argument('--tasks', default="all_tasks")
-    parser.add_argument('--provide_description', action="store_true")
-    parser.add_argument('--num_fewshot', type=int, default=0)
-    parser.add_argument('--batch_size', type=int, default=None)
-    parser.add_argument('--device', type=str, default=None)
-    parser.add_argument('--output_path', default=None)
-    parser.add_argument('--limit', type=int, default=None)
-    parser.add_argument('--no_cache', action="store_true")
-    parser.add_argument('--description_dict_path', default=None)
-    parser.add_argument('--check_integrity', action="store_true")
+    parser.add_argument("--model", required=True)
+    parser.add_argument("--model_args", default="")
+    parser.add_argument("--tasks", default="all_tasks")
+    parser.add_argument("--provide_description", action="store_true")
+    parser.add_argument("--num_fewshot", type=int, default=0)
+    parser.add_argument("--batch_size", type=int, default=None)
+    parser.add_argument("--device", type=str, default=None)
+    parser.add_argument("--output_path", default=None)
+    parser.add_argument("--limit", type=int, default=None)
+    parser.add_argument("--no_cache", action="store_true")
+    parser.add_argument("--description_dict_path", default=None)
+    parser.add_argument("--check_integrity", action="store_true")
 
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
 
     assert not args.provide_description  # not implemented
 
     if args.limit:
-        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
+        print(
+            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
+        )
 
     if args.tasks == "all_tasks":
         task_names = tasks.ALL_TASKS

@@ -38,7 +40,7 @@ def main():
     description_dict = {}
     if args.description_dict_path:
-        with open(args.description_dict_path, 'r') as f:
+        with open(args.description_dict_path, "r") as f:
             description_dict = json.load(f)
 
     results = evaluator.simple_evaluate(

@@ -51,11 +53,12 @@ def main():
         no_cache=args.no_cache,
         limit=args.limit,
         description_dict=description_dict,
-        check_integrity=args.check_integrity
+        check_integrity=args.check_integrity,
    )
 
-    print(results)
-
     dumped = json.dumps(results, indent=2)
     print(dumped)
 
     if args.output_path:
```
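Finally, a small standalone sketch of the serialization step `main()` now ends with: the full (and now nested) result dictionary is dumped with `json.dumps(..., indent=2)` and printed, rather than printing the raw `results` object. The sample dictionary is invented and only roughly mirrors the shape `evaluate()` returns after this commit; only the dump call itself comes from the diff.

```python
import json

# Invented sample, roughly in the shape of the new evaluate() return value.
results = {
    "results": [{"task_name": "coqa", "prompt_name": "first_qa_turn", "f1": 0.41}],
    "versions": {"coqa+first_qa_turn": 1},
    "examples": [{"doc_id": 7, "pred": "blue", "target": "blue"}],
    "table_results": {"coqa+first_qa_turn": {"f1": 0.41}},
}

dumped = json.dumps(results, indent=2)  # same call main() uses before printing
print(dumped)
```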