Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
opencompass
Commits
cbe9fe2c
Commit
cbe9fe2c
authored
Jul 05, 2023
by
Ezra-Yu
Committed by
gaotong
Jul 05, 2023
Browse files
Add Release Contribution
parent
36f11110
Changes
65
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
154 additions
and
0 deletions
+154
-0
opencompass/openicl/icl_retriever/icl_bm25_retriever.py
opencompass/openicl/icl_retriever/icl_bm25_retriever.py
+74
-0
opencompass/openicl/icl_retriever/icl_random_retriever.py
opencompass/openicl/icl_retriever/icl_random_retriever.py
+40
-0
opencompass/openicl/icl_retriever/icl_zero_retriever.py
opencompass/openicl/icl_retriever/icl_zero_retriever.py
+26
-0
opencompass/openicl/utils/__init__.py
opencompass/openicl/utils/__init__.py
+1
-0
opencompass/utils/logging.py
opencompass/utils/logging.py
+13
-0
No files found.
opencompass/openicl/icl_retriever/icl_bm25_retriever.py
0 → 100644
View file @
cbe9fe2c
"""BM25 Retriever."""
from
typing
import
List
,
Optional
import
numpy
as
np
from
nltk.tokenize
import
word_tokenize
from
rank_bm25
import
BM25Okapi
from
tqdm
import
trange
from
opencompass.openicl.icl_retriever
import
BaseRetriever
from
opencompass.openicl.utils.logging
import
get_logger
from
opencompass.registry
import
ICL_RETRIEVERS
# Module-level logger; BM25Retriever.retrieve uses it to announce the start
# of retrieval.
# NOTE(review): the only `get_logger` definition visible on this page takes
# `log_level` as its first positional parameter -- confirm that the
# `get_logger` imported here accepts a logger name.
logger = get_logger(__name__)
@ICL_RETRIEVERS.register_module()
class BM25Retriever(BaseRetriever):
    """Retriever that selects in-context examples by Okapi BM25 relevance.

    Okapi BM25 (BM stands for "best matching") is a ranking function used by
    search engines to estimate how relevant a document is to a search query;
    see https://en.wikipedia.org/wiki/Okapi_BM25 for details. Every index
    example is treated as a document and every test example as a query: for
    each test example the ``ice_num`` best-scoring index examples are chosen.

    The index and test examples are read from ``self.index_ds`` and
    ``self.test_ds``, which this class does not assign itself (presumably
    they are set up by ``BaseRetriever.__init__`` -- verify there).

    Args:
        dataset (`BaseDataset`): Any BaseDataset instances.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        ice_separator (`Optional[str]`): The separator between each
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to '\\n'.
        ice_eos_token (`Optional[str]`): The end of sentence token for
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to '\\n'.
        ice_num (`Optional[int]`): The number of in-context example template
            when origin `PromptTemplate` is provided. Defaults to 1.
    """

    # Populated in __init__; the class-level None values are placeholders.
    bm25 = None
    index_corpus = None
    test_corpus = None

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)

        # Tokenize the index examples and build the BM25 model over them.
        index_texts = self.dataset_reader.generate_input_field_corpus(
            self.index_ds)
        self.index_corpus = [word_tokenize(text) for text in index_texts]
        self.bm25 = BM25Okapi(self.index_corpus)

        # Tokenize the test examples; each token list is later used as a
        # BM25 query.
        test_texts = self.dataset_reader.generate_input_field_corpus(
            self.test_ds)
        self.test_corpus = [word_tokenize(text) for text in test_texts]

    def retrieve(self) -> List[List]:
        """Retrieve the in-context example index for each test example.

        Returns:
            List[List]: One list per test example, holding up to ``ice_num``
            index-example positions as plain ``int`` values, ordered from
            highest to lowest BM25 score.
        """
        logger.info('Retrieving data for test set...')
        top_k = self.ice_num
        ranked_ids = []
        # The progress bar is only displayed on the main process.
        for pos in trange(len(self.test_corpus),
                          disable=not self.is_main_process):
            relevance = self.bm25.get_scores(self.test_corpus[pos])
            # argsort is ascending, so reverse it to put the best match first.
            best_first = np.argsort(relevance)[::-1]
            # tolist() converts NumPy integers into built-in ints.
            ranked_ids.append(best_first[:top_k].tolist())
        return ranked_ids
opencompass/openicl/icl_retriever/icl_random_retriever.py
0 → 100644
View file @
cbe9fe2c
"""Random Retriever."""
from
typing
import
Optional
import
numpy
as
np
from
tqdm
import
trange
from
opencompass.openicl.icl_retriever
import
BaseRetriever
from
opencompass.openicl.utils.logging
import
get_logger
# Module-level logger; RandomRetriever.retrieve uses it to announce the start
# of retrieval.
# NOTE(review): the only `get_logger` definition visible on this page takes
# `log_level` as its first positional parameter -- confirm that the
# `get_logger` imported here accepts a logger name.
logger = get_logger(__name__)
class RandomRetriever(BaseRetriever):
    """Random Retriever. Each in-context example of the test prompts is
    retrieved in a random way.

    **WARNING**: This class has not been tested thoroughly. Please use it with
    caution.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instances.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        ice_separator (`Optional[str]`): The separator between each
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to '\\n'.
        ice_eos_token (`Optional[str]`): The end of sentence token for
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to '\\n'.
        ice_num (`Optional[int]`): The number of in-context example template
            when origin `PromptTemplate` is provided. Defaults to 1.
        seed (`Optional[int]`): Seed of the random number generator used for
            sampling. Defaults to 43.
    """

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1,
                 seed: Optional[int] = 43) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
        self.seed = seed

    def retrieve(self):
        """Sample in-context example indices for every test example.

        Returns:
            One list per test example, each holding ``ice_num`` distinct
            positions drawn from ``range(len(self.index_ds))``.

        Raises:
            ValueError: Raised by NumPy when ``ice_num`` exceeds the number
                of index examples, since sampling is without replacement.
        """
        # Use a private generator instead of np.random.seed(): it yields the
        # same draws as seeding the global legacy generator, but does not
        # reset the process-wide NumPy random state for unrelated code.
        rng = np.random.RandomState(self.seed)
        num_idx = len(self.index_ds)
        rtr_idx_list = []
        logger.info('Retrieving data for test set...')
        # The progress bar is only displayed on the main process.
        for _ in trange(len(self.test_ds), disable=not self.is_main_process):
            idx_list = rng.choice(num_idx, self.ice_num,
                                  replace=False).tolist()
            rtr_idx_list.append(idx_list)
        return rtr_idx_list
opencompass/openicl/icl_retriever/icl_zero_retriever.py
0 → 100644
View file @
cbe9fe2c
"""Zeroshot Retriever."""
from
typing
import
List
,
Optional
from
opencompass.openicl.icl_retriever
import
BaseRetriever
from
opencompass.registry
import
ICL_RETRIEVERS
@ICL_RETRIEVERS.register_module()
class ZeroRetriever(BaseRetriever):
    """Zero-shot retriever: no in-context example is selected for any query.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instances.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        ice_eos_token (`Optional[str]`): The end of sentence token for
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to ''.
    """

    def __init__(self, dataset, ice_eos_token: Optional[str] = '') -> None:
        # The fixed '' and 0 sit in the positions the sibling retrievers use
        # for `ice_separator` and `ice_num` respectively.
        super().__init__(dataset, '', ice_eos_token, 0)

    def retrieve(self) -> List[List]:
        """Return one empty index list per test example.

        Returns:
            List[List]: ``len(self.test_ds)`` separate empty lists.
        """
        query_count = len(self.test_ds)
        # Build a distinct list object for every test example.
        return [list() for _ in range(query_count)]
opencompass/openicl/utils/__init__.py
0 → 100644
View file @
cbe9fe2c
from
.logging
import
*
opencompass/utils/logging.py
0 → 100644
View file @
cbe9fe2c
from
mmengine.logging
import
MMLogger
def get_logger(log_level='INFO') -> MMLogger:
    """Get the logger for OpenCompass.

    Args:
        log_level (str): The log level. Default: 'INFO'. Choices are 'DEBUG',
            'INFO', 'WARNING', 'ERROR', 'CRITICAL'.

    Returns:
        MMLogger: The result of ``MMLogger.get_instance`` for the instance
        name 'OpenCompass', with the same string passed as ``logger_name``.
    """
    name = 'OpenCompass'
    opencompass_logger = MMLogger.get_instance(name,
                                               logger_name=name,
                                               log_level=log_level)
    return opencompass_logger
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment