OpenDAS / opencompass · Commits · dbb20b82

Unverified commit dbb20b82, authored Oct 27, 2023 by Fengzhe Zhou, committed by GitHub on Oct 27, 2023.

[Sync] update (#517)

Parent: 6f07af30

Changes: 45
Showing 5 changed files with 134 additions and 5 deletions (+134, -5)
opencompass/partitioners/sub_naive.py    +1    -1
opencompass/summarizers/default.py       +1    -0
opencompass/tasks/openicl_eval.py        +123  -4
requirements/runtime.txt                 +1    -0
run.py                                   +8    -0
opencompass/partitioners/sub_naive.py

@@ -23,7 +23,7 @@ class SubjectiveNaivePartitioner(NaivePartitioner):
                  mode: str,
                  out_dir: str,
                  model_pairs: Optional[List[Tuple]] = None,
-                 keep_keys: List[str] = ['eval.runner.task.judge_cfg']):
+                 keep_keys: Optional[List[str]] = None):
         super().__init__(out_dir=out_dir, keep_keys=keep_keys)
         assert mode in ['all', 'one_to_n', 'fixed']
         self.mode = mode
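With this change the default for keep_keys is no longer the hard-coded judge-config key; callers that relied on the old default would now pass it explicitly. A minimal usage sketch (the out_dir value is illustrative; only the signature above comes from the commit):

# Hypothetical call after this change: the judge-config key becomes opt-in.
partitioner = SubjectiveNaivePartitioner(
    mode='all',                                # must be 'all', 'one_to_n' or 'fixed'
    out_dir='outputs/partitions/',             # illustrative path
    keep_keys=['eval.runner.task.judge_cfg'],  # the old default, now passed explicitly
)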
opencompass/summarizers/default.py

@@ -72,6 +72,7 @@ class DefaultSummarizer:
             if not osp.exists(filepath):
                 continue
             result = mmengine.load(filepath)
+            result.pop('details', None)
             raw_results[model_abbr][dataset_abbr] = result
             if 'error' in result:
                 self.logger.debug(
                     f'error in {model_abbr} {dataset_abbr} {result["error"]}')
opencompass/tasks/openicl_eval.py

 import argparse
 import copy
 import fnmatch
 import math
 import os.path as osp
 import statistics
 import time
 from collections import Counter
 from inspect import signature
 from shutil import which
-from typing import Optional
+from typing import List, Optional

 import mmengine
 from mmengine.config import Config, ConfigDict
@@ -35,6 +38,8 @@ class OpenICLEvalTask(BaseTask):
         super().__init__(cfg)
         self.num_gpus = 0
         self.logger = get_logger()
+        self.dump_details = cfg.get('eval', {}).get('runner', {}).get(
+            'task', {}).get('dump_details', False)

     def get_command(self, cfg_path, template):
         script_path = __file__
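The chained .get() calls above look up a nested flag and fall back to False at every level. In config-dict form the flag would sit roughly here (a sketch, not a config shipped with the repo):

# Hypothetical config fragment showing where dump_details is read from.
cfg = dict(
    eval=dict(
        runner=dict(
            task=dict(
                dump_details=True,  # enables per-sample detail dumping
            ),
        ),
    ),
)
assert cfg.get('eval', {}).get('runner', {}).get('task', {}).get('dump_details', False)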
@@ -113,7 +118,7 @@ class OpenICLEvalTask(BaseTask):
                         [sub_preds[str(i)] for i in range(len(sub_preds))])
                     filename = root + f'_{i}' + ext
                     i += 1
+            pred_dicts = copy.deepcopy(preds)
             preds = {k: [pred.get(k) for pred in preds] for k in preds[0]}

             pred_strs = preds.pop('prediction')
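The dict comprehension in this hunk pivots a list of per-sample prediction dicts into a dict of per-field lists, while pred_dicts keeps an untouched copy for the new detail/BPB code further down. A standalone illustration with made-up values:

# Illustration only: list-of-dicts -> dict-of-lists, as in the hunk above.
preds = [
    {'prediction': 'A', 'gold': 'B'},
    {'prediction': 'C', 'gold': 'B'},
]
columns = {k: [pred.get(k) for pred in preds] for k in preds[0]}
print(columns)  # {'prediction': ['A', 'C'], 'gold': ['B', 'B']}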
@@ -163,6 +168,7 @@ class OpenICLEvalTask(BaseTask):
             ]

         icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
         preds['predictions'] = pred_strs
         preds['references'] = (test_set[self.output_column]
                                if self.output_column else None)
@@ -172,18 +178,42 @@ class OpenICLEvalTask(BaseTask):
         }
         result = icl_evaluator.score(**preds)

         if self.dump_details:
             try:
                 details = result.pop('details', None)
                 result['details'] = self.format_details(
                     pred_strs, test_set[self.output_column], details,
                     pred_dicts)
                 result['type'] = result['details'].pop('type', None)

                 if 'PPL' in str(self.dataset_cfg.infer_cfg.inferencer.type):
                     result['correct_bpb'], result['incorrect_bpb'] = \
                         self.calculate_bpb(pred_dicts)
                 else:
                     result['incorrect_bpb'] = result['correct_bpb'] = -1
             except Exception:
                 result['incorrect_bpb'] = result['correct_bpb'] = -1
         else:
             result.pop('details', None)

         if 'error' in result:
             self.logger.error(
                 f'Task {task_abbr_from_cfg(self.cfg)}: {result["error"]}')
             return
         else:
-            self.logger.info(
-                f'Task {task_abbr_from_cfg(self.cfg)}: {result}')
+            result_wo_details = {
+                i: result[i]
+                for i in result if i != 'details'
+            }
+            self.logger.info(
+                f'Task {task_abbr_from_cfg(self.cfg)}: {result_wo_details}')

         # Save result
         out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
                                          osp.join(self.work_dir, 'results'))
         mkdir_or_exist(osp.split(out_path)[0])
-        mmengine.dump(result, out_path)
+        mmengine.dump(result, out_path, ensure_ascii=False, indent=4)

     def _extract_role_pred(self, s: str, begin_str: Optional[str],
                            end_str: Optional[str]) -> str:
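Because preds carries 'predictions' and 'references' entries, the icl_evaluator.score(**preds) call above expands the dict into keyword arguments. A toy stand-in evaluator (not a class from the repo) makes the calling convention concrete:

# Toy evaluator, only to show how **preds is consumed as keyword arguments.
class ToyAccEvaluator:
    def score(self, predictions, references):
        correct = sum(p == r for p, r in zip(predictions, references))
        return {'accuracy': 100 * correct / len(references)}

preds = {'predictions': ['A', 'B'], 'references': ['A', 'C']}
print(ToyAccEvaluator().score(**preds))  # {'accuracy': 50.0}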
@@ -215,6 +245,95 @@ class OpenICLEvalTask(BaseTask):
         return s[start:end]

     def format_details(self, predictions, references, details, pred_dicts):
         """This function is responsible for formatting prediction details.

         Args:
             predictions (list): The prediction list.
             references (list): The reference list.
             details (list): Contains the 'pred' 'answer' and 'correct' for
                 each sample. Such as `[{'pred': '光荣和ωforce',
                 'answers': ['光荣和ω-force', '光荣和ωforce'], 'correct': True}]`
             pred_dicts (list): Contains a list of samples with the original
                 prompts. Such as
                 `[{'origin_prompt': '根据文章回答问题。你的答案应该尽可能3》…………',
                 'prediction': ' 光荣和ω-force\n', 'gold': ['光荣和ω-force']}]`

         Returns:
             list: The formatted prediction details.
         """
         results = {}
         for i in range(len(predictions)):
             ppl_flag = False
             result = {}
             origin_prediction = copy.deepcopy(pred_dicts[i])
             origin_prediction.pop('in-context examples', None)
             origin_prediction.pop('prediction', None)
             keys = copy.deepcopy(list(origin_prediction.keys()))
             for key in keys:
                 if key.startswith('label:'):
                     ppl_flag = True
                     origin_prediction[key].pop('testing input', None)
                     new_key = key.replace('label: ', '')
                     origin_prediction[new_key] = origin_prediction.pop(key)
             if ppl_flag:
                 results['type'] = 'PPL'
                 result['origin_prediction'] = origin_prediction
                 result['predictions'] = str(predictions[i])
                 result['references'] = str(references[i])
                 result['correct'] = str(predictions[i]) == str(references[i])
             else:
                 results['type'] = 'GEN'
                 result['prompt'] = origin_prediction['origin_prompt']
                 result['origin_prediction'] = pred_dicts[i]['prediction']
                 result['predictions'] = details[i]['pred']
                 result['references'] = details[i]['answers']
                 result['correct'] = details[i]['correct']
             results[str(i)] = result
         return results
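For a generation-style (GEN) dataset, the mapping returned by format_details would therefore look roughly like this; the values below are invented for illustration and are not part of the commit:

# Hypothetical return value of format_details for a GEN-type dataset.
example_details = {
    'type': 'GEN',
    '0': {
        'prompt': '...original prompt...',
        'origin_prediction': ' raw model output\n',
        'predictions': 'post-processed prediction',
        'references': ['gold answer'],
        'correct': True,
    },
}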
     def calculate_bpb(self, pred_dicts: List):
         """This function is used to calculate the BPB (Bits Per Byte) for the
         data. The correct BPB is obtained directly from the values in the
         'predictions' file. The incorrect BPB is the average of the remaining
         BPB values for each sample under different labels after subtracting
         the correct BPB. The calculation of BPB (Bits Per Byte) is similar to
         PPL, with the difference that it computes the additional bits needed
         on average, in terms of character length, to encode the true sequence
         based on the predictions. This calculation involves applying a
         weighting factor based on the ratio of words to characters.

         Args:
             pred_dicts (list): Contains a list of samples with each options
                 and BPB scores.

         Returns:
             dict: Contains correct and incorrect bpb.
         """
         incorrect_bpb_list = []
         bpb_list = []
         for pred_dict in pred_dicts:
             preds = {
                 key: value
                 for key, value in pred_dict.items()
                 if key.startswith('label: ')
             }
             values = []
             for item in preds.items():
                 values.append(item[1])
             bpbs = [value['BPB'] for value in values]
             incorrect_bpb_list.append(
                 (sum(bpbs) - min(bpbs)) / (len(bpbs) - 1))
             bpb_list.append(statistics.mean(bpbs))

         def filters(origins):
             targets = [target for target in origins if not math.isnan(target)]
             return targets

         mean_incorrect = statistics.mean(filters(incorrect_bpb_list))
         mean_correct = statistics.mean(filters(bpb_list))
         return 100 * mean_correct, 100 * mean_incorrect


 def parse_args():
     parser = argparse.ArgumentParser(description='Score Calculator')
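As a quick numeric illustration of the per-sample terms in calculate_bpb (toy BPB values, not from any real run):

import statistics

# Toy per-label BPB scores for a single sample (hypothetical values).
bpbs = [0.8, 1.4, 2.0]

# Per-sample terms, mirroring the loop body above:
incorrect_term = (sum(bpbs) - min(bpbs)) / (len(bpbs) - 1)  # (1.4 + 2.0) / 2 = 1.7
correct_term = statistics.mean(bpbs)                        # (0.8 + 1.4 + 2.0) / 3 = 1.4

# calculate_bpb averages these terms over all samples (dropping NaNs via
# filters()) and scales by 100 before returning (correct, incorrect).
print(100 * correct_term, 100 * incorrect_term)  # approx. 140.0 and 170.0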
requirements/runtime.txt

@@ -25,6 +25,7 @@ requests==2.31.0
 rouge
 rouge_chinese
 rouge_score
+sacrebleu
 scikit_learn==1.2.1
 seaborn
 sentence_transformers==2.2.2
run.py

@@ -123,6 +123,12 @@ def parse_args():
         'Will be overrideen by the "retry" argument in the config.',
         type=int,
         default=2)
+    parser.add_argument(
+        '--dump-eval-details',
+        help='Whether to dump the evaluation details, including the '
+        'correctness of each sample, bpb, etc.',
+        action='store_true',
+    )
     # set srun args
     slurm_parser = parser.add_argument_group('slurm_args')
     parse_slurm_args(slurm_parser)
@@ -300,6 +306,8 @@ def main():
         if args.dlc or args.slurm or cfg.get('eval', None) is None:
             fill_eval_cfg(cfg, args)
+        if args.dump_eval_details:
+            cfg.eval.runner.task.dump_details = True

         if args.partition is not None:
             if RUNNERS.get(cfg.eval.runner.type) == SlurmRunner:
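Putting the two run.py hunks together: the new flag is a plain store_true switch that, when set, flips the task-level dump_details option read by OpenICLEvalTask. A self-contained sketch of that flow (the plain-dict config is a stand-in for the real Config object):

# Sketch only: how --dump-eval-details would flow from the CLI into the config.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dump-eval-details', action='store_true')
args = parser.parse_args(['--dump-eval-details'])  # e.g. `python run.py <config> --dump-eval-details`

cfg = {'eval': {'runner': {'task': {}}}}  # stand-in for the loaded eval config
if args.dump_eval_details:
    cfg['eval']['runner']['task']['dump_details'] = True

print(cfg)  # {'eval': {'runner': {'task': {'dump_details': True}}}}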