OpenDAS / opencompass
Commit 655a807f (unverified)
Authored Aug 21, 2023 by philipwangOvO; committed via GitHub on Aug 21, 2023

[Dataset] LongBench (#236)

Co-authored-by: wangchonghua <wangchonghua@pjlab.org.cn>

Parent: c6a34949
Changes: 66
Showing 6 changed files with 169 additions and 9 deletions (+169 -9)

opencompass/datasets/longbench/longbench_repobench.py    +26  -0
opencompass/datasets/longbench/longbench_trec.py         +30  -0
opencompass/datasets/longbench/longbench_trivia_qa.py    +26  -0
opencompass/datasets/longbench/longbench_vcsum.py        +21  -0
opencompass/models/openai_api.py                         +64  -9
requirements/runtime.txt                                  +2  -0
opencompass/datasets/longbench/longbench_repobench.py  (new file, 0 → 100644)
from datasets import Dataset, load_dataset

from opencompass.registry import LOAD_DATASET

from ..base import BaseDataset


@LOAD_DATASET.register_module()
class LongBenchrepobenchDataset(BaseDataset):

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)
        split = 'test'
        raw_data = []
        for i in range(len(dataset[split])):
            question = dataset[split]['input'][i]
            context = dataset[split]['context'][i]
            answers = dataset[split]['answers'][i]
            raw_data.append({
                'input': question,
                'context': context,
                'answers': answers
            })
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
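For orientation, a minimal usage sketch of this loader follows. The Hugging Face hub id 'THUDM/LongBench' and the subset name 'repobench-p' are assumptions for illustration and are not part of this commit; load(**kwargs) simply forwards its arguments to datasets.load_dataset.

# Hypothetical usage sketch; hub id and subset name are assumed, not from this commit.
from opencompass.datasets.longbench.longbench_repobench import (
    LongBenchrepobenchDataset)

ds = LongBenchrepobenchDataset.load(path='THUDM/LongBench', name='repobench-p')
# The 'test' split is rebuilt so each record carries exactly these fields:
print(ds['test'][0].keys())  # should show 'input', 'context', 'answers'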
opencompass/datasets/longbench/longbench_trec.py  (new file, 0 → 100644)
from datasets import Dataset, load_dataset

from opencompass.registry import LOAD_DATASET

from ..base import BaseDataset


@LOAD_DATASET.register_module()
class LongBenchtrecDataset(BaseDataset):

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)
        split = 'test'
        raw_data = []
        for i in range(len(dataset[split])):
            question = dataset[split]['input'][i]
            context = dataset[split]['context'][i]
            answers = dataset[split]['answers'][i]
            all_classes = dataset[split]['all_classes'][i]
            raw_data.append({
                'input': question,
                'context': context,
                'all_labels': {
                    'answers': answers,
                    'all_classes': all_classes
                }
            })
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
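The TREC loader differs from the other three in that it nests the ground truth under a single 'all_labels' field, presumably so a downstream evaluator can read the gold answers and the candidate class list together. A sketch of one rebuilt record follows; all values are invented placeholders.

# Placeholder record illustrating the structure produced above; values are made up.
record = {
    'input': 'What city hosted the 1992 Summer Olympics?',   # the question
    'context': '...few-shot examples from the TREC subset...',
    'all_labels': {
        'answers': ['LOC:city'],                  # gold label(s)
        'all_classes': ['LOC:city', 'NUM:date'],  # full candidate label set
    },
}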
opencompass/datasets/longbench/longbench_trivia_qa.py  (new file, 0 → 100644)
from datasets import Dataset, load_dataset

from opencompass.registry import LOAD_DATASET

from ..base import BaseDataset


@LOAD_DATASET.register_module()
class LongBenchtriviaqaDataset(BaseDataset):

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)
        split = 'test'
        raw_data = []
        for i in range(len(dataset[split])):
            question = dataset[split]['input'][i]
            context = dataset[split]['context'][i]
            answers = dataset[split]['answers'][i]
            raw_data.append({
                'input': question,
                'context': context,
                'answers': answers
            })
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
opencompass/datasets/longbench/longbench_vcsum.py  (new file, 0 → 100644)
from datasets import Dataset, load_dataset

from opencompass.registry import LOAD_DATASET

from ..base import BaseDataset


@LOAD_DATASET.register_module()
class LongBenchvcsumDataset(BaseDataset):

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)
        split = 'test'
        raw_data = []
        for i in range(len(dataset[split])):
            context = dataset[split]['context'][i]
            answers = dataset[split]['answers'][i]
            raw_data.append({'context': context, 'answers': answers})
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
opencompass/models/openai_api.py  (modified)
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Dict, List, Optional, Union

import jieba
import requests

from opencompass.registry import MODELS
...
@@ -42,6 +44,9 @@ class OpenAI(BaseAPIModel):
            wrapping of any meta instructions.
        openai_api_base (str): The base url of OpenAI's API. Defaults to
            'https://api.openai.com/v1/chat/completions'.
        mode (str, optional): The method of input truncation when the input
            length exceeds max_seq_len. 'front', 'mid' and 'rear' indicate
            which part of the input to truncate. Defaults to 'none'.
        temperature (float, optional): What sampling temperature to use.
            If not None, it will override the temperature in the `generate()`
            call. Defaults to None.
...
@@ -58,6 +63,7 @@ class OpenAI(BaseAPIModel):
                 org: Optional[Union[str, List[str]]] = None,
                 meta_template: Optional[Dict] = None,
                 openai_api_base: str = OPENAI_API_BASE,
                 mode: str = 'none',
                 temperature: Optional[float] = None):
        super().__init__(path=path,
...
@@ -68,6 +74,8 @@ class OpenAI(BaseAPIModel):
        import tiktoken
        self.tiktoken = tiktoken
        self.temperature = temperature
        assert mode in ['none', 'front', 'mid', 'rear']
        self.mode = mode

        if isinstance(key, str):
            self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
...
@@ -137,6 +145,20 @@ class OpenAI(BaseAPIModel):
"""
assert
isinstance
(
input
,
(
str
,
PromptList
))
# max num token for gpt-3.5-turbo is 4097
context_window
=
4096
if
'32k'
in
self
.
path
:
context_window
=
32768
elif
'16k'
in
self
.
path
:
context_window
=
16384
elif
'gpt-4'
in
self
.
path
:
context_window
=
8192
# will leave 100 tokens as prompt buffer, triggered if input is str
if
isinstance
(
input
,
str
)
and
self
.
mode
!=
'none'
:
context_window
=
self
.
max_seq_len
input
=
self
.
bin_trim
(
input
,
context_window
-
100
-
max_out_len
)
if
isinstance
(
input
,
str
):
messages
=
[{
'role'
:
'user'
,
'content'
:
input
}]
else
:
...
...
@@ -151,15 +173,6 @@ class OpenAI(BaseAPIModel):
                    msg['role'] = 'system'
                messages.append(msg)

        # max num token for gpt-3.5-turbo is 4097
        context_window = 4096
        if '32k' in self.path:
            context_window = 32768
        elif '16k' in self.path:
            context_window = 16384
        elif 'gpt-4' in self.path:
            context_window = 8192
        # Hold out 100 tokens due to potential errors in tiktoken calculation
        max_out_len = min(
            max_out_len,
            context_window - self.get_token_len(str(input)) - 100)
...
@@ -251,3 +264,45 @@ class OpenAI(BaseAPIModel):
"""
enc
=
self
.
tiktoken
.
encoding_for_model
(
self
.
path
)
return
len
(
enc
.
encode
(
prompt
))
def
bin_trim
(
self
,
prompt
:
str
,
num_token
:
int
)
->
str
:
"""Get a suffix of prompt which is no longer than num_token tokens.
Args:
prompt (str): Input string.
num_token (int): The upper bound of token numbers.
Returns:
str: The trimmed prompt.
"""
token_len
=
self
.
get_token_len
(
prompt
)
if
token_len
<=
num_token
:
return
prompt
pattern
=
re
.
compile
(
r
'[\u4e00-\u9fa5]'
)
if
pattern
.
search
(
prompt
):
words
=
list
(
jieba
.
cut
(
prompt
,
cut_all
=
False
))
else
:
words
=
prompt
.
split
(
' '
)
l
,
r
=
1
,
len
(
words
)
while
l
+
2
<
r
:
mid
=
(
l
+
r
)
//
2
if
self
.
mode
==
'front'
:
cur_prompt
=
' '
.
join
(
words
[
-
mid
:])
elif
self
.
mode
==
'mid'
:
cur_prompt
=
' '
.
join
(
words
[:
mid
])
+
' '
.
join
(
words
[
-
mid
:])
elif
self
.
mode
==
'rear'
:
cur_prompt
=
' '
.
join
(
words
[:
mid
])
if
self
.
get_token_len
(
cur_prompt
)
<=
num_token
:
l
=
mid
# noqa: E741
else
:
r
=
mid
if
self
.
mode
==
'front'
:
prompt
=
' '
.
join
(
words
[
-
l
:])
elif
self
.
mode
==
'mid'
:
prompt
=
' '
.
join
(
words
[:
l
])
+
' '
.
join
(
words
[
-
l
:])
elif
self
.
mode
==
'rear'
:
prompt
=
' '
.
join
(
words
[:
l
])
return
prompt
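The new mode option and bin_trim work together: when a plain-string prompt is longer than the model's window, bin_trim binary-searches the number of words (jieba tokens for Chinese text) to keep so that the tiktoken length fits the budget, dropping the front, middle, or rear of the prompt according to mode. The standalone sketch below mirrors that binary search with a toy token counter; count_tokens, bin_trim_sketch, and the numbers are stand-ins for illustration, not code from this commit.

# Standalone sketch of the binary-search trimming idea (illustrative names and numbers).
def count_tokens(text: str) -> int:
    # Toy stand-in for self.get_token_len: roughly one token per four characters.
    return max(1, len(text) // 4)


def bin_trim_sketch(prompt: str, num_token: int, mode: str = 'rear') -> str:
    """Trim `prompt` until count_tokens(prompt) <= num_token, keeping the part
    of the text that `mode` says to preserve (mirrors OpenAI.bin_trim)."""
    if count_tokens(prompt) <= num_token:
        return prompt
    words = prompt.split(' ')
    lo, hi = 1, len(words)
    while lo + 2 < hi:               # binary-search the largest keepable size
        mid = (lo + hi) // 2
        if mode == 'front':          # truncate the front, keep the tail
            cur = ' '.join(words[-mid:])
        elif mode == 'mid':          # keep both ends, drop the middle
            cur = ' '.join(words[:mid]) + ' '.join(words[-mid:])
        else:                        # 'rear': truncate the rear, keep the head
            cur = ' '.join(words[:mid])
        if count_tokens(cur) <= num_token:
            lo = mid
        else:
            hi = mid
    if mode == 'front':
        return ' '.join(words[-lo:])
    if mode == 'mid':
        return ' '.join(words[:lo]) + ' '.join(words[-lo:])
    return ' '.join(words[:lo])


long_prompt = 'lorem ipsum ' * 2000
trimmed = bin_trim_sketch(long_prompt, num_token=500, mode='mid')
print(count_tokens(long_prompt), '->', count_tokens(trimmed))  # trimmed count stays <= 500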
requirements/runtime.txt  (modified)
...
@@ -7,6 +7,7 @@ datasets>=2.12.0
evaluate>=0.3.0
fairscale
faiss_gpu==1.7.2
fuzzywuzzy
jieba
mmengine>=0.8.2
nltk==3.8
...
@@ -16,6 +17,7 @@ pandas<2.0.0
rank_bm25==0.2.2
rapidfuzz
requests==2.31.0
rouge
rouge_score
scikit_learn==1.2.1
sentence_transformers==2.2.2
...