Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
opencompass
Commits
b39f5015
Unverified
Commit
b39f5015
authored
Apr 09, 2024
by
Fengzhe Zhou
Committed by
GitHub
Apr 09, 2024
Browse files
[Sync] update taco (#1030)
parent
16f29b25
Changes
87
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
107 additions
and
437 deletions
+107
-437
opencompass/utils/prompt.py
opencompass/utils/prompt.py
+5
-5
opencompass/utils/run.py
opencompass/utils/run.py
+14
-3
opencompass/utils/text_postprocessors.py
opencompass/utils/text_postprocessors.py
+5
-5
requirements/runtime.txt
requirements/runtime.txt
+1
-0
run.py
run.py
+1
-361
setup.py
setup.py
+40
-33
tools/prediction_merger.py
tools/prediction_merger.py
+41
-30
No files found.
opencompass/utils/prompt.py
View file @
b39f5015
...
...
@@ -73,7 +73,7 @@ class PromptList(list):
Args:
src (str): The string to be replaced.
dst (
str or
Prompt
List
): The string or PromptList to replace with.
dst (Prompt
Type
): The string or PromptList to replace with.
Returns:
PromptList: A new PromptList with 'src' replaced by 'dst'.
...
...
@@ -98,7 +98,7 @@ class PromptList(list):
Args:
src (str): The string to be replaced.
dst (
str or
Prompt
List
): The string or PromptList to replace with.
dst (Prompt
Type
): The string or PromptList to replace with.
Returns:
PromptList: A new PromptList with 'src' replaced by 'dst'.
...
...
@@ -139,7 +139,7 @@ class PromptList(list):
"""Adds a string or another PromptList to this PromptList.
Args:
other (
str or
Prompt
List
): The string or PromptList to be added.
other (Prompt
Type
): The string or PromptList to be added.
Returns:
PromptList: A new PromptList that is the result of the addition.
...
...
@@ -156,7 +156,7 @@ class PromptList(list):
'+' operator.
Args:
other (
str or
Prompt
List
): The string or PromptList to be added.
other (Prompt
Type
): The string or PromptList to be added.
Returns:
PromptList: A new PromptList that is the result of the addition.
...
...
@@ -172,7 +172,7 @@ class PromptList(list):
"""Implements in-place addition for the PromptList.
Args:
other (
str or
Prompt
List
): The string or PromptList to be added.
other (Prompt
Type
): The string or PromptList to be added.
Returns:
PromptList: The updated PromptList.
...
...
opencompass/utils/run.py
View file @
b39f5015
...
...
@@ -48,6 +48,19 @@ def match_cfg_file(workdir: str, pattern: Union[str, List[str]]) -> List[str]:
return
files
def
try_fill_in_custom_cfgs
(
config
):
for
i
,
dataset
in
enumerate
(
config
[
'datasets'
]):
if
'type'
not
in
dataset
:
config
[
'datasets'
][
i
]
=
make_custom_dataset_config
(
dataset
)
if
'model_dataset_combinations'
not
in
config
:
return
config
for
mdc
in
config
[
'model_dataset_combinations'
]:
for
i
,
dataset
in
enumerate
(
mdc
[
'datasets'
]):
if
'type'
not
in
dataset
:
mdc
[
'datasets'
][
i
]
=
make_custom_dataset_config
(
dataset
)
return
config
def
get_config_from_arg
(
args
)
->
Config
:
"""Get the config object given args.
...
...
@@ -58,9 +71,7 @@ def get_config_from_arg(args) -> Config:
"""
if
args
.
config
:
config
=
Config
.
fromfile
(
args
.
config
,
format_python_code
=
False
)
for
i
,
dataset
in
enumerate
(
config
[
'datasets'
]):
if
'type'
not
in
dataset
:
config
[
'datasets'
][
i
]
=
make_custom_dataset_config
(
dataset
)
config
=
try_fill_in_custom_cfgs
(
config
)
return
config
# parse dataset args
if
not
args
.
datasets
and
not
args
.
custom_dataset_path
:
...
...
opencompass/utils/text_postprocessors.py
View file @
b39f5015
...
...
@@ -94,11 +94,11 @@ def first_option_postprocess(text: str, options: str, cushion=True) -> str:
f
'答案是\s?(\S+)(?:。|$)'
,
f
'答案应该是\s?(\S+)(?:。|$)'
,
f
'答案为\s?(\S+)(?:。|$)'
,
f
'[Tt]he answer is ([
{
options
}
])'
,
f
'[Tt]he answer is option ([
{
options
}
])'
,
f
'[Tt]he correct answer is ([
{
options
}
])'
,
f
'[Tt]he correct answer is option ([
{
options
}
])'
,
f
'[Tt]he answer to the question is ([
{
options
}
])'
,
f
'[Tt]he answer is
\(?
([
{
options
}
])
\)?
'
,
f
'[Tt]he answer is option
\(?
([
{
options
}
])
\)?
'
,
f
'[Tt]he correct answer is
\(?
([
{
options
}
])
\)?
'
,
f
'[Tt]he correct answer is option
\(?
([
{
options
}
])
\)?
'
,
f
'[Tt]he answer to the question is
\(?
([
{
options
}
])
\)?
'
,
f
'^选项\s?([
{
options
}
])'
,
f
'^([
{
options
}
])\s?选?项'
,
f
'(\s|^)[
{
options
}
][\s。,,::\.$]'
,
...
...
requirements/runtime.txt
View file @
b39f5015
...
...
@@ -21,6 +21,7 @@ OpenCC
opencv-python-headless
pandas<2.0.0
prettytable
pyext
pypinyin
python-Levenshtein
rank_bm25==0.2.2
...
...
run.py
View file @
b39f5015
import
argparse
import
getpass
import
os
import
os.path
as
osp
from
datetime
import
datetime
from
mmengine.config
import
Config
,
DictAction
from
opencompass.partitioners
import
MultimodalNaivePartitioner
from
opencompass.registry
import
PARTITIONERS
,
RUNNERS
,
build_from_cfg
from
opencompass.runners
import
SlurmRunner
from
opencompass.summarizers
import
DefaultSummarizer
from
opencompass.utils
import
LarkReporter
,
get_logger
from
opencompass.utils.run
import
(
exec_mm_infer_runner
,
fill_eval_cfg
,
fill_infer_cfg
,
get_config_from_arg
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Run an evaluation task'
)
parser
.
add_argument
(
'config'
,
nargs
=
'?'
,
help
=
'Train config file path'
)
# add mutually exclusive args `--slurm` `--dlc`, defaults to local runner
# if "infer" or "eval" not specified
launch_method
=
parser
.
add_mutually_exclusive_group
()
launch_method
.
add_argument
(
'--slurm'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Whether to force tasks to run with srun. '
'If True, `--partition(-p)` must be set. '
'Defaults to False'
)
launch_method
.
add_argument
(
'--dlc'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Whether to force tasks to run on dlc. If '
'True, `--aliyun-cfg` must be set. Defaults'
' to False'
)
# multi-modal support
parser
.
add_argument
(
'--mm-eval'
,
help
=
'Whether or not enable multimodal evaluation'
,
action
=
'store_true'
,
default
=
False
)
# Add shortcut parameters (models, datasets and summarizer)
parser
.
add_argument
(
'--models'
,
nargs
=
'+'
,
help
=
''
,
default
=
None
)
parser
.
add_argument
(
'--datasets'
,
nargs
=
'+'
,
help
=
''
,
default
=
None
)
parser
.
add_argument
(
'--summarizer'
,
help
=
''
,
default
=
None
)
# add general args
parser
.
add_argument
(
'--debug'
,
help
=
'Debug mode, in which scheduler will run tasks '
'in the single process, and output will not be '
'redirected to files'
,
action
=
'store_true'
,
default
=
False
)
parser
.
add_argument
(
'--dry-run'
,
help
=
'Dry run mode, in which the scheduler will not '
'actually run the tasks, but only print the commands '
'to run'
,
action
=
'store_true'
,
default
=
False
)
parser
.
add_argument
(
'-m'
,
'--mode'
,
help
=
'Running mode. You can choose "infer" if you '
'only want the inference results, or "eval" if you '
'already have the results and want to evaluate them, '
'or "viz" if you want to visualize the results.'
,
choices
=
[
'all'
,
'infer'
,
'eval'
,
'viz'
],
default
=
'all'
,
type
=
str
)
parser
.
add_argument
(
'-r'
,
'--reuse'
,
nargs
=
'?'
,
type
=
str
,
const
=
'latest'
,
help
=
'Reuse previous outputs & results, and run any '
'missing jobs presented in the config. If its '
'argument is not specified, the latest results in '
'the work_dir will be reused. The argument should '
'also be a specific timestamp, e.g. 20230516_144254'
),
parser
.
add_argument
(
'-w'
,
'--work-dir'
,
help
=
'Work path, all the outputs will be '
'saved in this path, including the slurm logs, '
'the evaluation results, the summary results, etc.'
'If not specified, the work_dir will be set to '
'./outputs/default.'
,
default
=
None
,
type
=
str
)
parser
.
add_argument
(
'--config-dir'
,
default
=
'configs'
,
help
=
'Use the custom config directory instead of config/ to '
'search the configs for datasets, models and summarizers'
,
type
=
str
)
parser
.
add_argument
(
'-l'
,
'--lark'
,
help
=
'Report the running status to lark bot'
,
action
=
'store_true'
,
default
=
False
)
parser
.
add_argument
(
'--max-partition-size'
,
help
=
'The maximum size of an infer task. Only '
'effective when "infer" is missing from the config.'
,
type
=
int
,
default
=
40000
),
parser
.
add_argument
(
'--gen-task-coef'
,
help
=
'The dataset cost measurement coefficient for generation tasks, '
'Only effective when "infer" is missing from the config.'
,
type
=
int
,
default
=
20
)
parser
.
add_argument
(
'--max-num-workers'
,
help
=
'Max number of workers to run in parallel. '
'Will be overrideen by the "max_num_workers" argument '
'in the config.'
,
type
=
int
,
default
=
32
)
parser
.
add_argument
(
'--max-workers-per-gpu'
,
help
=
'Max task to run in parallel on one GPU. '
'It will only be used in the local runner.'
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
'--retry'
,
help
=
'Number of retries if the job failed when using slurm or dlc. '
'Will be overrideen by the "retry" argument in the config.'
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
'--dump-eval-details'
,
help
=
'Whether to dump the evaluation details, including the '
'correctness of each sample, bpb, etc.'
,
action
=
'store_true'
,
)
# set srun args
slurm_parser
=
parser
.
add_argument_group
(
'slurm_args'
)
parse_slurm_args
(
slurm_parser
)
# set dlc args
dlc_parser
=
parser
.
add_argument_group
(
'dlc_args'
)
parse_dlc_args
(
dlc_parser
)
# set hf args
hf_parser
=
parser
.
add_argument_group
(
'hf_args'
)
parse_hf_args
(
hf_parser
)
# set custom dataset args
custom_dataset_parser
=
parser
.
add_argument_group
(
'custom_dataset_args'
)
parse_custom_dataset_args
(
custom_dataset_parser
)
args
=
parser
.
parse_args
()
if
args
.
slurm
:
assert
args
.
partition
is
not
None
,
(
'--partition(-p) must be set if you want to use slurm'
)
if
args
.
dlc
:
assert
os
.
path
.
exists
(
args
.
aliyun_cfg
),
(
'When launching tasks using dlc, it needs to be configured '
'in "~/.aliyun.cfg", or use "--aliyun-cfg $ALiYun-CFG_Path"'
' to specify a new path.'
)
return
args
def
parse_slurm_args
(
slurm_parser
):
"""These args are all for slurm launch."""
slurm_parser
.
add_argument
(
'-p'
,
'--partition'
,
help
=
'Slurm partition name'
,
default
=
None
,
type
=
str
)
slurm_parser
.
add_argument
(
'-q'
,
'--quotatype'
,
help
=
'Slurm quota type'
,
default
=
None
,
type
=
str
)
slurm_parser
.
add_argument
(
'--qos'
,
help
=
'Slurm quality of service'
,
default
=
None
,
type
=
str
)
def
parse_dlc_args
(
dlc_parser
):
"""These args are all for dlc launch."""
dlc_parser
.
add_argument
(
'--aliyun-cfg'
,
help
=
'The config path for aliyun config'
,
default
=
'~/.aliyun.cfg'
,
type
=
str
)
def
parse_hf_args
(
hf_parser
):
"""These args are all for the quick construction of HuggingFace models."""
hf_parser
.
add_argument
(
'--hf-path'
,
type
=
str
)
hf_parser
.
add_argument
(
'--peft-path'
,
type
=
str
)
hf_parser
.
add_argument
(
'--tokenizer-path'
,
type
=
str
)
hf_parser
.
add_argument
(
'--model-kwargs'
,
nargs
=
'+'
,
action
=
DictAction
,
default
=
{})
hf_parser
.
add_argument
(
'--tokenizer-kwargs'
,
nargs
=
'+'
,
action
=
DictAction
,
default
=
{})
hf_parser
.
add_argument
(
'--max-out-len'
,
type
=
int
)
hf_parser
.
add_argument
(
'--max-seq-len'
,
type
=
int
)
hf_parser
.
add_argument
(
'--no-batch-padding'
,
action
=
'store_true'
,
default
=
False
)
hf_parser
.
add_argument
(
'--batch-size'
,
type
=
int
)
hf_parser
.
add_argument
(
'--num-gpus'
,
type
=
int
)
hf_parser
.
add_argument
(
'--pad-token-id'
,
type
=
int
)
def
parse_custom_dataset_args
(
custom_dataset_parser
):
"""These args are all for the quick construction of custom datasets."""
custom_dataset_parser
.
add_argument
(
'--custom-dataset-path'
,
type
=
str
)
custom_dataset_parser
.
add_argument
(
'--custom-dataset-meta-path'
,
type
=
str
)
custom_dataset_parser
.
add_argument
(
'--custom-dataset-data-type'
,
type
=
str
,
choices
=
[
'mcq'
,
'qa'
])
custom_dataset_parser
.
add_argument
(
'--custom-dataset-infer-method'
,
type
=
str
,
choices
=
[
'gen'
,
'ppl'
])
def
main
():
args
=
parse_args
()
if
args
.
dry_run
:
args
.
debug
=
True
# initialize logger
logger
=
get_logger
(
log_level
=
'DEBUG'
if
args
.
debug
else
'INFO'
)
cfg
=
get_config_from_arg
(
args
)
if
args
.
work_dir
is
not
None
:
cfg
[
'work_dir'
]
=
args
.
work_dir
else
:
cfg
.
setdefault
(
'work_dir'
,
'./outputs/default/'
)
# cfg_time_str defaults to the current time
cfg_time_str
=
dir_time_str
=
datetime
.
now
().
strftime
(
'%Y%m%d_%H%M%S'
)
if
args
.
reuse
:
if
args
.
reuse
==
'latest'
:
if
not
os
.
path
.
exists
(
cfg
.
work_dir
)
or
not
os
.
listdir
(
cfg
.
work_dir
):
logger
.
warning
(
'No previous results to reuse!'
)
else
:
dirs
=
os
.
listdir
(
cfg
.
work_dir
)
dir_time_str
=
sorted
(
dirs
)[
-
1
]
else
:
dir_time_str
=
args
.
reuse
logger
.
info
(
f
'Reusing experiements from
{
dir_time_str
}
'
)
elif
args
.
mode
in
[
'eval'
,
'viz'
]:
raise
ValueError
(
'You must specify -r or --reuse when running in eval '
'or viz mode!'
)
# update "actual" work_dir
cfg
[
'work_dir'
]
=
osp
.
join
(
cfg
.
work_dir
,
dir_time_str
)
os
.
makedirs
(
osp
.
join
(
cfg
.
work_dir
,
'configs'
),
exist_ok
=
True
)
# dump config
output_config_path
=
osp
.
join
(
cfg
.
work_dir
,
'configs'
,
f
'
{
cfg_time_str
}
.py'
)
cfg
.
dump
(
output_config_path
)
# Config is intentally reloaded here to avoid initialized
# types cannot be serialized
cfg
=
Config
.
fromfile
(
output_config_path
,
format_python_code
=
False
)
# report to lark bot if specify --lark
if
not
args
.
lark
:
cfg
[
'lark_bot_url'
]
=
None
elif
cfg
.
get
(
'lark_bot_url'
,
None
):
content
=
f
'
{
getpass
.
getuser
()
}
\'
s task has been launched!'
LarkReporter
(
cfg
[
'lark_bot_url'
]).
post
(
content
)
if
args
.
mode
in
[
'all'
,
'infer'
]:
# When user have specified --slurm or --dlc, or have not set
# "infer" in config, we will provide a default configuration
# for infer
if
(
args
.
dlc
or
args
.
slurm
)
and
cfg
.
get
(
'infer'
,
None
):
logger
.
warning
(
'You have set "infer" in the config, but '
'also specified --slurm or --dlc. '
'The "infer" configuration will be overridden by '
'your runtime arguments.'
)
# Check whether run multimodal evaluation
if
args
.
mm_eval
:
partitioner
=
MultimodalNaivePartitioner
(
osp
.
join
(
cfg
[
'work_dir'
],
'predictions/'
))
tasks
=
partitioner
(
cfg
)
exec_mm_infer_runner
(
tasks
,
args
,
cfg
)
return
if
args
.
dlc
or
args
.
slurm
or
cfg
.
get
(
'infer'
,
None
)
is
None
:
fill_infer_cfg
(
cfg
,
args
)
if
args
.
partition
is
not
None
:
if
RUNNERS
.
get
(
cfg
.
infer
.
runner
.
type
)
==
SlurmRunner
:
cfg
.
infer
.
runner
.
partition
=
args
.
partition
cfg
.
infer
.
runner
.
quotatype
=
args
.
quotatype
else
:
logger
.
warning
(
'SlurmRunner is not used, so the partition '
'argument is ignored.'
)
if
args
.
debug
:
cfg
.
infer
.
runner
.
debug
=
True
if
args
.
lark
:
cfg
.
infer
.
runner
.
lark_bot_url
=
cfg
[
'lark_bot_url'
]
cfg
.
infer
.
partitioner
[
'out_dir'
]
=
osp
.
join
(
cfg
[
'work_dir'
],
'predictions/'
)
partitioner
=
PARTITIONERS
.
build
(
cfg
.
infer
.
partitioner
)
tasks
=
partitioner
(
cfg
)
if
args
.
dry_run
:
return
runner
=
RUNNERS
.
build
(
cfg
.
infer
.
runner
)
# Add extra attack config if exists
if
hasattr
(
cfg
,
'attack'
):
for
task
in
tasks
:
cfg
.
attack
.
dataset
=
task
.
datasets
[
0
][
0
].
abbr
task
.
attack
=
cfg
.
attack
runner
(
tasks
)
# evaluate
if
args
.
mode
in
[
'all'
,
'eval'
]:
# When user have specified --slurm or --dlc, or have not set
# "eval" in config, we will provide a default configuration
# for eval
if
(
args
.
dlc
or
args
.
slurm
)
and
cfg
.
get
(
'eval'
,
None
):
logger
.
warning
(
'You have set "eval" in the config, but '
'also specified --slurm or --dlc. '
'The "eval" configuration will be overridden by '
'your runtime arguments.'
)
if
args
.
dlc
or
args
.
slurm
or
cfg
.
get
(
'eval'
,
None
)
is
None
:
fill_eval_cfg
(
cfg
,
args
)
if
args
.
dump_eval_details
:
cfg
.
eval
.
runner
.
task
.
dump_details
=
True
if
args
.
partition
is
not
None
:
if
RUNNERS
.
get
(
cfg
.
eval
.
runner
.
type
)
==
SlurmRunner
:
cfg
.
eval
.
runner
.
partition
=
args
.
partition
cfg
.
eval
.
runner
.
quotatype
=
args
.
quotatype
else
:
logger
.
warning
(
'SlurmRunner is not used, so the partition '
'argument is ignored.'
)
if
args
.
debug
:
cfg
.
eval
.
runner
.
debug
=
True
if
args
.
lark
:
cfg
.
eval
.
runner
.
lark_bot_url
=
cfg
[
'lark_bot_url'
]
cfg
.
eval
.
partitioner
[
'out_dir'
]
=
osp
.
join
(
cfg
[
'work_dir'
],
'results/'
)
partitioner
=
PARTITIONERS
.
build
(
cfg
.
eval
.
partitioner
)
tasks
=
partitioner
(
cfg
)
if
args
.
dry_run
:
return
runner
=
RUNNERS
.
build
(
cfg
.
eval
.
runner
)
# For meta-review-judge in subjective evaluation
if
isinstance
(
tasks
,
list
)
and
len
(
tasks
)
!=
0
and
isinstance
(
tasks
[
0
],
list
):
for
task_part
in
tasks
:
runner
(
task_part
)
else
:
runner
(
tasks
)
# visualize
if
args
.
mode
in
[
'all'
,
'eval'
,
'viz'
]:
summarizer_cfg
=
cfg
.
get
(
'summarizer'
,
{})
if
not
summarizer_cfg
or
summarizer_cfg
.
get
(
'type'
,
None
)
is
None
:
summarizer_cfg
[
'type'
]
=
DefaultSummarizer
summarizer_cfg
[
'config'
]
=
cfg
summarizer
=
build_from_cfg
(
summarizer_cfg
)
summarizer
.
summarize
(
time_str
=
cfg_time_str
)
from
opencompass.cli.main
import
main
if
__name__
==
'__main__'
:
main
()
setup.py
View file @
b39f5015
...
...
@@ -103,7 +103,8 @@ def get_version():
def
do_setup
():
setup
(
name
=
'opencompass'
,
setup
(
name
=
'opencompass'
,
author
=
'OpenCompass Contributors'
,
version
=
get_version
(),
description
=
'A comprehensive toolkit for large model evaluation'
,
...
...
@@ -114,7 +115,7 @@ def do_setup():
cmdclass
=
{
'download_nltk'
:
DownloadNLTK
},
setup_requires
=
[
'nltk==3.8'
],
python_requires
=
'>=3.8.0'
,
install_requires
=
parse_requirements
(
'requirements/runtime.txt'
),
#
install_requires=parse_requirements('requirements/runtime.txt'),
license
=
'Apache License 2.0'
,
packages
=
find_packages
(
exclude
=
[
'test*'
,
...
...
@@ -135,7 +136,13 @@ def do_setup():
'Intended Audience :: Developers'
,
'Intended Audience :: Education'
,
'Intended Audience :: Science/Research'
,
])
],
entry_points
=
{
'console_scripts'
:
[
'opencompass = opencompass.cli.main:main'
,
],
},
)
if
__name__
==
'__main__'
:
...
...
tools/prediction_merger.py
View file @
b39f5015
import
argparse
import
copy
import
json
import
os
.path
as
osp
import
os
import
mmengine
from
mmengine.config
import
Config
,
ConfigDict
...
...
@@ -13,24 +13,16 @@ def parse_args():
parser
=
argparse
.
ArgumentParser
(
description
=
'Merge patitioned predictions'
)
parser
.
add_argument
(
'config'
,
help
=
'Train config file path'
)
parser
.
add_argument
(
'-w'
,
'--work-dir'
,
help
=
'Work path, all the outputs will be '
'saved in this path, including the slurm logs, '
'the evaluation results, the summary results, etc.'
'If not specified, the work_dir will be set to '
'./outputs/default.'
,
default
=
None
,
type
=
str
)
parser
.
add_argument
(
'-w'
,
'--work-dir'
,
default
=
None
,
type
=
str
)
parser
.
add_argument
(
'-r'
,
'--reuse'
,
default
=
'latest'
,
type
=
str
)
parser
.
add_argument
(
'-c'
,
'--clean'
,
action
=
'store_true'
)
args
=
parser
.
parse_args
()
return
args
class
PredictionMerger
:
""""""
def
__init__
(
self
,
cfg
:
ConfigDict
)
->
None
:
self
.
cfg
=
cfg
self
.
model_cfg
=
copy
.
deepcopy
(
self
.
cfg
[
'model'
])
self
.
dataset_cfg
=
copy
.
deepcopy
(
self
.
cfg
[
'dataset'
])
...
...
@@ -39,26 +31,23 @@ class PredictionMerger:
def
run
(
self
):
filename
=
get_infer_output_path
(
self
.
model_cfg
,
self
.
dataset_cfg
,
os
p
.
join
(
self
.
work_dir
,
'predictions'
))
root
,
ext
=
os
p
.
splitext
(
filename
)
os
.
path
.
join
(
self
.
work_dir
,
'predictions'
))
root
,
ext
=
os
.
path
.
splitext
(
filename
)
partial_filename
=
root
+
'_0'
+
ext
if
os
p
.
exists
(
os
p
.
realpath
(
filename
)):
if
os
.
path
.
exists
(
os
.
path
.
realpath
(
filename
)):
return
if
not
os
p
.
exists
(
os
p
.
realpath
(
partial_filename
)):
if
not
os
.
path
.
exists
(
os
.
path
.
realpath
(
partial_filename
)):
print
(
f
'
{
filename
}
not found'
)
return
# Load predictions
partial_filenames
=
[]
if
osp
.
exists
(
osp
.
realpath
(
filename
)):
preds
=
mmengine
.
load
(
filename
)
else
:
preds
,
offset
=
{},
0
i
=
1
while
os
p
.
exists
(
os
p
.
realpath
(
partial_filename
)):
partial_filenames
.
append
(
os
p
.
realpath
(
partial_filename
))
while
os
.
path
.
exists
(
os
.
path
.
realpath
(
partial_filename
)):
partial_filenames
.
append
(
os
.
path
.
realpath
(
partial_filename
))
_preds
=
mmengine
.
load
(
partial_filename
)
partial_filename
=
root
+
f
'_
{
i
}
'
+
ext
i
+=
1
...
...
@@ -75,6 +64,11 @@ class PredictionMerger:
with
open
(
filename
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
preds
,
f
,
indent
=
4
,
ensure_ascii
=
False
)
if
self
.
cfg
[
'clean'
]:
for
partial_filename
in
partial_filenames
:
print
(
f
'Remove
{
partial_filename
}
'
)
os
.
remove
(
partial_filename
)
def
dispatch_tasks
(
cfg
):
for
model
in
cfg
[
'models'
]:
...
...
@@ -82,7 +76,8 @@ def dispatch_tasks(cfg):
PredictionMerger
({
'model'
:
model
,
'dataset'
:
dataset
,
'work_dir'
:
cfg
[
'work_dir'
]
'work_dir'
:
cfg
[
'work_dir'
],
'clean'
:
cfg
[
'clean'
]
}).
run
()
...
...
@@ -94,6 +89,22 @@ def main():
cfg
[
'work_dir'
]
=
args
.
work_dir
else
:
cfg
.
setdefault
(
'work_dir'
,
'./outputs/default'
)
if
args
.
reuse
:
if
args
.
reuse
==
'latest'
:
if
not
os
.
path
.
exists
(
cfg
.
work_dir
)
or
not
os
.
listdir
(
cfg
.
work_dir
):
print
(
'No previous results to reuse!'
)
return
else
:
dirs
=
os
.
listdir
(
cfg
.
work_dir
)
dir_time_str
=
sorted
(
dirs
)[
-
1
]
else
:
dir_time_str
=
args
.
reuse
cfg
[
'work_dir'
]
=
os
.
path
.
join
(
cfg
.
work_dir
,
dir_time_str
)
cfg
[
'clean'
]
=
args
.
clean
dispatch_tasks
(
cfg
)
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment