Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
opencompass
Commits
cbe9fe2c
Commit
cbe9fe2c
authored
Jul 05, 2023
by
Ezra-Yu
Committed by
gaotong
Jul 05, 2023
Browse files
Add Release Contribution
parent
36f11110
Changes
65
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
608 additions
and
0 deletions
+608
-0
configs/datasets/summscreen/summscreen_gen_997ee2.py
configs/datasets/summscreen/summscreen_gen_997ee2.py
+35
-0
configs/datasets/winograd/winograd_ppl.py
configs/datasets/winograd/winograd_ppl.py
+4
-0
configs/datasets/z_bench/z_bench_gen_61db0a.py
configs/datasets/z_bench/z_bench_gen_61db0a.py
+28
-0
docs/en/Makefile
docs/en/Makefile
+20
-0
docs/en/prompt/prompt_template.md
docs/en/prompt/prompt_template.md
+1
-0
docs/en/user_guides/models.md
docs/en/user_guides/models.md
+1
-0
docs/zh_cn/_static/css/readthedocs.css
docs/zh_cn/_static/css/readthedocs.css
+62
-0
docs/zh_cn/_static/image/logo.png
docs/zh_cn/_static/image/logo.png
+0
-0
docs/zh_cn/_templates/callable.rst
docs/zh_cn/_templates/callable.rst
+14
-0
docs/zh_cn/notes/contribution_guide.md
docs/zh_cn/notes/contribution_guide.md
+67
-0
opencompass/datasets/arc.py
opencompass/datasets/arc.py
+45
-0
opencompass/datasets/flores.py
opencompass/datasets/flores.py
+36
-0
opencompass/datasets/qasper.py
opencompass/datasets/qasper.py
+43
-0
opencompass/datasets/qaspercut.py
opencompass/datasets/qaspercut.py
+53
-0
opencompass/datasets/safety.py
opencompass/datasets/safety.py
+23
-0
opencompass/datasets/triviaqarc.py
opencompass/datasets/triviaqarc.py
+58
-0
opencompass/datasets/winogrande.py
opencompass/datasets/winogrande.py
+44
-0
opencompass/datasets/xcopa.py
opencompass/datasets/xcopa.py
+29
-0
opencompass/openicl/icl_evaluator/icl_aucroc_evaluator.py
opencompass/openicl/icl_evaluator/icl_aucroc_evaluator.py
+41
-0
opencompass/openicl/icl_inferencer/__init__.py
opencompass/openicl/icl_inferencer/__init__.py
+4
-0
No files found.
configs/datasets/summscreen/summscreen_gen_997ee2.py
0 → 100644
View file @
cbe9fe2c
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import BleuEvaluator
from opencompass.datasets import SummScreenDataset

# Reader: both the train and test sides read from the 'dev' split.
summscreen_reader_cfg = dict(
    input_columns='content',
    output_column='summary',
    train_split='dev',
    test_split='dev')

# Inference: zero-shot generation with a plain summarization prompt.
summscreen_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=
        "Please summarize the following English report in English:{content}\n{summary}."
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(
        type=GenInferencer, batch_size=4, max_out_len=500, max_seq_len=8192))

# Evaluation: BLEU against the reference summary.
# NOTE(review): 'general_cn' postprocessors look Chinese-oriented while the
# prompt asks for English output — confirm this pairing is intended.
summscreen_eval_cfg = dict(
    evaluator=dict(type=BleuEvaluator),
    pred_postprocessor=dict(type='general_cn'),
    dataset_postprocessor=dict(type='general_cn'))

summscreen_datasets = [
    dict(
        type=SummScreenDataset,
        path='./data/SummScreen/',
        abbr='SummScreen',
        reader_cfg=summscreen_reader_cfg,
        infer_cfg=summscreen_infer_cfg,
        eval_cfg=summscreen_eval_cfg)
]
configs/datasets/winograd/winograd_ppl.py
0 → 100644
View file @
cbe9fe2c
from mmengine.config import read_base

# Indirection config: re-export the concrete Winograd PPL dataset list
# defined in the hash-suffixed sibling config.
with read_base():
    from .winograd_ppl_c1c427 import winograd_datasets  # noqa: F401, F403
configs/datasets/z_bench/z_bench_gen_61db0a.py
0 → 100644
View file @
cbe9fe2c
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HFDataset

# Reader: sample 4 items; model input is 'text', reference is 'category'.
z_bench_reader_cfg = dict(
    ds_size=4,
    input_columns=['text'],
    output_column='category',
    train_split='test')

# Inference: the raw text is sent as a single human turn, zero-shot.
z_bench_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[dict(role="HUMAN", prompt="{text}")]),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer))

# NOTE(review): path/data_dir are hard-coded to a cluster-local mount and
# will not resolve elsewhere — consider parameterizing before release.
z_bench_dataset = dict(
    type=HFDataset,
    path='/mnt/petrelfs/gaotong/llm_eval/openagieval_dataset/eval_datasets/z_bench',
    data_dir='/mnt/petrelfs/gaotong/llm_eval/openagieval_dataset/eval_datasets/z_bench',
    name='question',
    reader_cfg=z_bench_reader_cfg,
    infer_cfg=z_bench_infer_cfg)
docs/en/Makefile
0 → 100644
View file @
cbe9fe2c
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = .
BUILDDIR      = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/en/prompt/prompt_template.md
0 → 100644
View file @
cbe9fe2c
# Prompt Template
\ No newline at end of file
docs/en/user_guides/models.md
0 → 100644
View file @
cbe9fe2c
# Prepare Models
docs/zh_cn/_static/css/readthedocs.css
0 → 100644
View file @
cbe9fe2c
/* Theme overrides for the zh_cn Read the Docs build (pytorch-sphinx theme). */

.header-logo {
    background-image: url("../image/logo.png");
    background-size: 183px 50px;
    height: 50px;
    width: 183px;
}

@media screen and (min-width: 1100px) {
    .header-logo {
        top: -12px;
    }
}

/* Keep preformatted blocks from wrapping. */
pre {
    white-space: pre;
}

/* Widen the layout on very large screens. */
@media screen and (min-width: 2000px) {
    .pytorch-content-left {
        width: 1200px;
        margin-left: 30px;
    }

    article.pytorch-article {
        max-width: 1200px;
    }

    .pytorch-breadcrumbs-wrapper {
        width: 1200px;
    }

    .pytorch-right-menu.scrolling-fixed {
        position: fixed;
        top: 45px;
        left: 1580px;
    }
}

article.pytorch-article section code {
    padding: .2em .4em;
    background-color: #f3f4f7;
    border-radius: 5px;
}

/* Disable the change in tables */
article.pytorch-article section table code {
    padding: unset;
    background-color: unset;
    border-radius: unset;
}

table.autosummary td {
    width: 50%
}

img.align-center {
    display: block;
    margin-left: auto;
    margin-right: auto;
}

article.pytorch-article p.rubric {
    font-weight: bold;
}
docs/zh_cn/_static/image/logo.png
0 → 100644
View file @
cbe9fe2c
12.4 KB
docs/zh_cn/_templates/callable.rst
0 → 100644
View file @
cbe9fe2c
.. role:: hidden
:class: hidden-section
.. currentmodule:: {{ module }}
{{ name | underline}}
.. autoclass:: {{ name }}
:members:
:special-members: __call__
..
autogenerated from _templates/callable.rst
note it does not have :inherited-members:
docs/zh_cn/notes/contribution_guide.md
0 → 100644
View file @
cbe9fe2c
# 为 OpenCompass 做贡献
-
[
为OpenCompass做贡献
](
#为opencompass做贡献
)
-
[
工作流程
](
#工作流程
)
-
[
代码风格
](
#代码风格
)
-
[
Python
](
#python
)
-
[
预提交钩子 (Pre-commit Hook)
](
#预提交钩子-pre-commit-hook
)
感谢你对于OpenCompass的贡献!我们欢迎各种形式的贡献,包括但不限于以下几点。
-
修改错别字或修复bug
-
添加文档或将文档翻译成其它语言
-
添加新功能和组件
## 工作流程
我们建议潜在的贡献者遵循以下的贡献工作流程。
1.
Fork并拉取最新的OpenCompass仓库,按照
[
开始使用
](
https://OpenCompass.readthedocs.io/en/latest/get_started.html
)
来设置环境。
2.
检出一个新的分支(
**不要使用master或dev分支来创建PR**
)
```
bash
git checkout
-b
xxxx
# xxxx 是新分支的名称
```
3.
编辑相关文件,并且遵循下面提到的代码风格
4.
使用
[
预提交钩子
](
https://pre-commit.com/
)
来检查和格式化你的更改。
5.
提交你的更改
6.
创建一个带有相关信息的PR
## 代码风格
### Python
我们采用
[
PEP8
](
https://www.python.org/dev/peps/pep-0008/
)
作为首选的代码风格。
我们使用以下工具进行linting和格式化:
-
[
flake8
](
https://github.com/PyCQA/flake8
)
: 一个围绕一些linter工具的封装器。
-
[
isort
](
https://github.com/timothycrosley/isort
)
: 一个用于排序Python导入的实用程序。
-
[
yapf
](
https://github.com/google/yapf
)
: 一个Python文件的格式化器。
-
[
codespell
](
https://github.com/codespell-project/codespell
)
: 一个Python实用程序,用于修复文本文件中常见的拼写错误。
-
[
mdformat
](
https://github.com/executablebooks/mdformat
)
: mdformat是一个有明确定义的Markdown格式化程序,可以用来在Markdown文件中强制执行一致的样式。
-
[
docformatter
](
https://github.com/myint/docformatter
)
: 一个格式化docstring的工具。
yapf和isort的样式配置可以在
[
setup.cfg
](
https://github.com/open-compass/opencompass/blob/main/setup.cfg
)
中找到。
## 预提交钩子 (Pre-commit Hook)
我们使用
[
预提交钩子
](
https://pre-commit.com/
)
用于在每次提交时自动检查与格式化
`flake8`
、
`yapf`
、
`isort`
、
`trailing whitespaces`
、
`markdown files`
,
修复
`end-of-files`
、
`double-quoted-strings`
、
`python-encoding-pragma`
、
`mixed-line-ending`
,并自动排序
`requirements.txt`
。预提交钩子的配置存储在
[.pre-commit-config.yaml](https://github.com/open-compass/opencompass/blob/main/.pre-commit-config.yaml)中。
在你克隆仓库后,你需要安装并初始化预提交钩子。
```
shell
pip
install
-U
pre-commit
```
从仓库文件夹运行
```
shell
pre-commit
install
```
之后,在每次提交时都会强制执行代码 linters 和格式化器。
> 在你创建PR前,确保你的代码通过了 lint 检查并被 yapf 格式化。
\ No newline at end of file
opencompass/datasets/arc.py
0 → 100644
View file @
cbe9fe2c
import
json
from
datasets
import
Dataset
from
opencompass.registry
import
LOAD_DATASET
from
.base
import
BaseDataset
@LOAD_DATASET.register_module()
class ARCDataset(BaseDataset):
    """ARC multiple-choice dataset loaded from a local jsonl file."""

    @staticmethod
    def load(path: str):
        """Read *path* (one JSON object per line) into a Dataset.

        Only questions with exactly four choices are kept; each kept row
        carries the stem, the gold answer key and the four option texts.
        """
        rows = []
        with open(path, 'r', errors='ignore') as in_f:
            for line in in_f:
                sample = json.loads(line.strip())
                answer_key = sample['answerKey']
                question_block = sample['question']
                choices = question_block['choices']
                # Skip the occasional 3- or 5-way question.
                if len(choices) != 4:
                    continue
                rows.append({
                    'question': question_block['stem'],
                    'answerKey': answer_key,
                    'textA': choices[0]['text'],
                    'textB': choices[1]['text'],
                    'textC': choices[2]['text'],
                    'textD': choices[3]['text'],
                })
        columns = ('question', 'answerKey', 'textA', 'textB', 'textC',
                   'textD')
        return Dataset.from_dict(
            {col: [row[col] for row in rows] for col in columns})
opencompass/datasets/flores.py
0 → 100644
View file @
cbe9fe2c
import
re
from
datasets
import
DatasetDict
,
load_dataset
from
opencompass.registry
import
LOAD_DATASET
,
TEXT_POSTPROCESSORS
from
.base
import
BaseDataset
@LOAD_DATASET.register_module()
class FloresFirst100Dataset(BaseDataset):
    """Flores-200 with 'devtest' truncated to its first 100 rows."""

    @staticmethod
    def load(name):
        # Map output split name -> HuggingFace split spec.
        split_specs = {
            'dev': 'dev',
            'devtest': 'devtest[:100]',
        }
        return DatasetDict({
            out_split: load_dataset(path='facebook/flores',
                                    name=name,
                                    split=spec)
            for out_split, spec in split_specs.items()
        })
@TEXT_POSTPROCESSORS.register_module('flores')
def flores_postprocess(text: str) -> str:
    """Keep only the first line of the stripped model output."""
    first_line, _, _ = text.strip().partition('\n')
    return first_line
@TEXT_POSTPROCESSORS.register_module('flores-chinese')
def flores_postprocess_chinese(text: str) -> str:
    """First line of the output, whitespace-normalized and segmented.

    The Chinese text is tokenized with jieba and re-joined with single
    spaces so BLEU-style metrics operate on words rather than characters.
    """
    import jieba

    first_line = text.strip().split('\n')[0]
    normalized = re.sub(r'\s+', ' ', first_line).strip()
    return ' '.join(jieba.cut(normalized))
opencompass/datasets/qasper.py
0 → 100644
View file @
cbe9fe2c
from
datasets
import
Dataset
,
DatasetDict
from
opencompass.registry
import
LOAD_DATASET
from
.base
import
BaseDataset
@LOAD_DATASET.register_module()
class QASPERDataset(BaseDataset):
    """QASPER paper-QA dataset built from the local dev-split JSON dump."""

    @staticmethod
    def load(path: str):
        """Build a DatasetDict with a 'dev' split from *path*.

        Args:
            path (str): Directory containing ``qasper-dev-v0.3.json``.

        Returns:
            DatasetDict: one 'dev' split with 'answer' (list of extractive
            spans), 'question' and 'evidence' (the full article text).
        """
        import json
        import os
        dataset_dict = DatasetDict()
        split = 'dev'
        dev_list = []
        dev = os.path.join(path, 'qasper-dev-v0.3.json')
        with open(dev, 'r') as f:
            dev_json = json.load(f)
        for article_id in dev_json.keys():
            # Flatten the article: each section contributes its (possibly
            # empty) name followed by its paragraphs, newline-separated.
            full_article = '\n'.join([
                (x['section_name'] if x['section_name'] else '') + '\n' +
                '\n'.join(x['paragraphs']) + '\n'
                for x in dev_json[article_id]['full_text']
            ])
            for qa in dev_json[article_id]['qas']:
                question = qa['question']
                answers = []
                for x in qa['answers']:
                    answers.extend(x['answer']['extractive_spans'])
                # Questions with no extractive answer are dropped.
                if answers:
                    dev_list.append({
                        'answer': answers,
                        'question': question,
                        'evidence': full_article,
                    })
                else:
                    continue
        dataset_dict[split] = Dataset.from_list(dev_list)
        return dataset_dict
opencompass/datasets/qaspercut.py
0 → 100644
View file @
cbe9fe2c
from
datasets
import
Dataset
,
DatasetDict
from
opencompass.registry
import
LOAD_DATASET
from
.base
import
BaseDataset
@LOAD_DATASET.register_module()
class QASPERCUTDataset(BaseDataset):
    """QASPER variant whose evidence is cut to start at the first clue."""

    @staticmethod
    def load(path: str):
        """Build a DatasetDict with a 'dev' split from *path*.

        Like QASPERDataset, but the 'evidence' field is the article text
        starting at the earliest occurrence of any annotated evidence
        snippet, so the relevant passage sits near the beginning.

        Args:
            path (str): Directory containing ``qasper-dev-v0.3.json``.

        Returns:
            DatasetDict: one 'dev' split with 'answer', 'question' and the
            truncated 'evidence'.
        """
        import json
        import os
        dataset_dict = DatasetDict()
        split = 'dev'
        dev_list = []
        dev = os.path.join(path, 'qasper-dev-v0.3.json')
        with open(dev, 'r') as f:
            dev_json = json.load(f)
        for article_id in dev_json.keys():
            # Flatten the article: section name + paragraphs per section.
            full_article = '\n'.join([
                (x['section_name'] if x['section_name'] else '') + '\n' +
                '\n'.join(x['paragraphs']) + '\n'
                for x in dev_json[article_id]['full_text']
            ])
            for qa in dev_json[article_id]['qas']:
                question = qa['question']
                answers = []
                clues = []
                for x in qa['answers']:
                    answers.extend(x['answer']['extractive_spans'])
                    clues.extend(x['answer']['evidence'])
                # Earliest position of any clue in the article; the large
                # sentinel keeps min() defined when there are no clues.
                evis = [full_article.find(clue)
                        for clue in clues] + [100000000]
                evi = min(evis)
                # -1 means no clue was found verbatim; 100000000 means no
                # clues at all — in both cases keep the whole article.
                if evi == -1 or evi == 100000000:
                    evi = 0
                if answers:
                    dev_list.append({
                        'answer': answers,
                        'question': question,
                        'evidence': full_article[evi:],
                    })
                else:
                    continue
        dataset_dict[split] = Dataset.from_list(dev_list)
        return dataset_dict
opencompass/datasets/safety.py
0 → 100644
View file @
cbe9fe2c
from
datasets
import
Dataset
,
DatasetDict
from
opencompass.registry
import
LOAD_DATASET
from
.base
import
BaseDataset
@LOAD_DATASET.register_module()
class SafetyDataset(BaseDataset):
    """Prompt-only safety probe set read from a plain-text file.

    Each non-blank line of the file becomes one sample carrying a running
    index and the raw prompt text; all samples go into a 'test' split.
    """

    @staticmethod
    def load(path):
        """Build a DatasetDict with a single 'test' split from *path*.

        Args:
            path (str): Text file with one prompt per line; blank lines
                are skipped.

        Returns:
            DatasetDict: {'test': Dataset with 'idx' and 'prompt' columns}.
        """
        dataset = DatasetDict()
        data_list = []
        idx = 0
        with open(path, 'r') as f:
            for line in f:
                if line.strip():
                    data_list.append({'idx': idx, 'prompt': line.strip()})
                    idx += 1
        dataset['test'] = Dataset.from_list(data_list)
        # Bug fix: the original function fell off the end and implicitly
        # returned None, so callers never received the built dataset.
        return dataset
opencompass/datasets/triviaqarc.py
0 → 100644
View file @
cbe9fe2c
from
datasets
import
Dataset
,
DatasetDict
from
opencompass.registry
import
LOAD_DATASET
from
.base
import
BaseDataset
@LOAD_DATASET.register_module()
class TriviaQArcDataset(BaseDataset):
    """TriviaQA (verified dev) with reading-comprehension evidence attached.

    Merges the verified web and verified wikipedia dev files; each sample
    carries the candidate answers, the question and up to the first 100k
    characters of the first associated evidence document.
    """

    @staticmethod
    def load(path: str):
        """Build a DatasetDict with a 'dev' split from *path*.

        Args:
            path (str): TriviaQA-rc root containing ``qa/`` and
                ``evidence/`` subdirectories.

        Returns:
            DatasetDict: one 'dev' split with 'answer', 'question' and
            'evidence' columns.
        """
        import json
        import os

        def _collect(qa_file, page_key, evidence_subdir, with_human_answers):
            """Load one QA json and attach evidence for each question."""
            entries = []
            with open(os.path.join(path, 'qa', qa_file), 'r') as f:
                qa_json = json.load(f)
            for x in qa_json['Data']:
                cand_answers = list(x['Answer']['Aliases'])
                # Only the web file carries human-written answers; the
                # wikipedia file has no 'HumanAnswers' field.
                if with_human_answers:
                    cand_answers += x['Answer']['HumanAnswers']
                evidence = ''
                if x[page_key]:
                    ev_path = os.path.join(path, 'evidence', evidence_subdir,
                                           x[page_key][0]['Filename'])
                    with open(ev_path, 'r') as ev_f:
                        # Cap evidence at 100k characters.
                        evidence = ev_f.read(100000)
                entries.append({
                    'answer': cand_answers,
                    'question': x['Question'],
                    'evidence': evidence,
                })
            return entries

        # Web entries first, then wikipedia, matching the original order.
        dev_list = _collect('verified-web-dev.json', 'SearchResults', 'web',
                            True)
        dev_list += _collect('verified-wikipedia-dev.json', 'EntityPages',
                             'wikipedia', False)
        dataset_dict = DatasetDict()
        dataset_dict['dev'] = Dataset.from_list(dev_list)
        return dataset_dict
opencompass/datasets/winogrande.py
0 → 100644
View file @
cbe9fe2c
from
datasets
import
load_dataset
from
opencompass.registry
import
LOAD_DATASET
from
.base
import
BaseDataset
@LOAD_DATASET.register_module()
class winograndeDataset(BaseDataset):
    """Winogrande loaded via HuggingFace with both options substituted in."""

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)

        def preprocess(example):
            # Fill the blank ('_') of the sentence with each option.
            sentence = example.pop('sentence')
            for option_key, filled_key in (('option1', 'opt1'),
                                           ('option2', 'opt2')):
                example[filled_key] = sentence.replace(
                    '_', example.pop(option_key))
            return example

        return dataset.map(preprocess)
@LOAD_DATASET.register_module()
class winograndeDataset_V2(BaseDataset):
    """Winogrande variant that also maps the answer to an A/B label."""

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)

        def preprocess(example):
            # Fill the blank ('_') of the sentence with each option.
            sentence = example.pop('sentence')
            example['opt1'] = sentence.replace('_', example.pop('option1'))
            example['opt2'] = sentence.replace('_', example.pop('option2'))
            raw_answer = example.pop('answer')
            # '1' -> 'A', '2' -> 'B'; an empty answer means unlabelled.
            example['label'] = ('NULL' if raw_answer == '' else
                                ' AB'[int(raw_answer)])
            return example

        return dataset.map(preprocess)
opencompass/datasets/xcopa.py
0 → 100644
View file @
cbe9fe2c
from
datasets
import
concatenate_datasets
,
load_dataset
from
opencompass.registry
import
LOAD_DATASET
from
.base
import
BaseDataset
@LOAD_DATASET.register_module()
class XCOPADataset(BaseDataset):
    """All XCOPA validation splits concatenated into one dataset."""

    @staticmethod
    def load(**kwargs):
        path = kwargs.get('path', None)
        base_langs = [
            'et', 'ht', 'it', 'id', 'qu', 'sw', 'zh', 'ta', 'th', 'tr', 'vi'
        ]
        # Every language except 'qu' also ships a machine-translated
        # variant; keep the original ordering (originals, then translations).
        translated = [
            'translation-' + lan for lan in base_langs if lan != 'qu'
        ]
        subsets = []
        for lan in base_langs + translated:
            subsets.append(load_dataset(path, lan)['validation'])
        return concatenate_datasets(subsets)
opencompass/openicl/icl_evaluator/icl_aucroc_evaluator.py
0 → 100644
View file @
cbe9fe2c
from
typing
import
List
import
numpy
as
np
from
sklearn.metrics
import
roc_auc_score
from
opencompass.registry
import
ICL_EVALUATORS
from
.icl_base_evaluator
import
BaseEvaluator
@ICL_EVALUATORS.register_module()
class AUCROCEvaluator(BaseEvaluator):
    """Calculate AUC-ROC scores and accuracy according the prediction.

    For some dataset, the accuracy cannot reveal the difference between
    models because of the saturation. AUC-ROC scores can further exam
    model abilities to distinguish different labels. More details can refer to
    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
    """  # noqa

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions: List, references: List) -> dict:
        """Calculate scores and accuracy.

        Args:
            predictions (List): List of probabilities for each class of each
                sample.
            references (List): List of target labels for each sample.

        Returns:
            dict: calculated scores.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different length.'
            }
        # AUC is computed from the positive-class probability (column 1);
        # assumes binary classification — TODO confirm for multi-class use.
        auc_score = roc_auc_score(references, np.array(predictions)[:, 1])
        # Accuracy: fraction of samples whose argmax class matches the label.
        accuracy = sum(
            references == np.argmax(predictions, axis=1)) / len(references)
        # Both metrics are reported as percentages.
        return dict(auc_score=auc_score * 100, accuracy=accuracy * 100)
opencompass/openicl/icl_inferencer/__init__.py
0 → 100644
View file @
cbe9fe2c
# Public re-exports for the icl_inferencer subpackage.
from .icl_base_inferencer import BaseInferencer
from .icl_gen_inferencer import GenInferencer
from .icl_ppl_inferencer import PPLInferencer
from .icl_clp_inferencer import CLPInferencer
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment