Unverified Commit 655a807f authored by philipwangOvO's avatar philipwangOvO Committed by GitHub
Browse files

[Dataset] LongBench (#236)


Co-authored-by: default avatarwangchonghua <wangchonghua@pjlab.org.cn>
parent c6a34949
from datasets import Dataset, load_dataset
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class LongBenchrepobenchDataset(BaseDataset):
    """Loader for the LongBench RepoBench-P subset.

    Rebuilds the ``test`` split so each example keeps only the
    ``input``/``context``/``answers`` fields expected downstream.
    """

    @staticmethod
    def load(**kwargs):
        """Load the dataset via ``datasets.load_dataset(**kwargs)`` and
        return it with a normalized ``test`` split.
        """
        dataset = load_dataset(**kwargs)
        split = 'test'
        # Iterate rows directly: indexing a column (dataset[split]['input'])
        # inside a loop re-materializes the whole column per access, which
        # made the original pattern O(n^2).
        raw_data = [{
            'input': row['input'],
            'context': row['context'],
            'answers': row['answers'],
        } for row in dataset[split]]
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
from datasets import Dataset, load_dataset
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class LongBenchtrecDataset(BaseDataset):
    """Loader for the LongBench TREC (classification) subset.

    Rebuilds the ``test`` split; ``answers`` and ``all_classes`` are
    nested under ``all_labels`` as the evaluator expects.
    """

    @staticmethod
    def load(**kwargs):
        """Load the dataset via ``datasets.load_dataset(**kwargs)`` and
        return it with a normalized ``test`` split.
        """
        dataset = load_dataset(**kwargs)
        split = 'test'
        # Row iteration avoids the O(n^2) per-index column
        # re-materialization of the original range(len(...)) loop.
        raw_data = [{
            'input': row['input'],
            'context': row['context'],
            'all_labels': {
                'answers': row['answers'],
                'all_classes': row['all_classes'],
            },
        } for row in dataset[split]]
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
from datasets import Dataset, load_dataset
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class LongBenchtriviaqaDataset(BaseDataset):
    """Loader for the LongBench TriviaQA subset.

    Rebuilds the ``test`` split so each example keeps only the
    ``input``/``context``/``answers`` fields expected downstream.
    """

    @staticmethod
    def load(**kwargs):
        """Load the dataset via ``datasets.load_dataset(**kwargs)`` and
        return it with a normalized ``test`` split.
        """
        dataset = load_dataset(**kwargs)
        split = 'test'
        # Row iteration avoids the O(n^2) per-index column
        # re-materialization of the original range(len(...)) loop.
        raw_data = [{
            'input': row['input'],
            'context': row['context'],
            'answers': row['answers'],
        } for row in dataset[split]]
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
from datasets import Dataset, load_dataset
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class LongBenchvcsumDataset(BaseDataset):
    """Loader for the LongBench VCSUM (summarization) subset.

    VCSUM has no per-example question, so only ``context`` and
    ``answers`` are kept in the rebuilt ``test`` split.
    """

    @staticmethod
    def load(**kwargs):
        """Load the dataset via ``datasets.load_dataset(**kwargs)`` and
        return it with a normalized ``test`` split.
        """
        dataset = load_dataset(**kwargs)
        split = 'test'
        # Row iteration avoids the O(n^2) per-index column
        # re-materialization of the original range(len(...)) loop.
        raw_data = [{
            'context': row['context'],
            'answers': row['answers'],
        } for row in dataset[split]]
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
import json import json
import os import os
import re
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from threading import Lock from threading import Lock
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import jieba
import requests import requests
from opencompass.registry import MODELS from opencompass.registry import MODELS
...@@ -42,6 +44,9 @@ class OpenAI(BaseAPIModel): ...@@ -42,6 +44,9 @@ class OpenAI(BaseAPIModel):
wrapping of any meta instructions. wrapping of any meta instructions.
openai_api_base (str): The base url of OpenAI's API. Defaults to openai_api_base (str): The base url of OpenAI's API. Defaults to
'https://api.openai.com/v1/chat/completions'. 'https://api.openai.com/v1/chat/completions'.
mode (str, optional): The method of input truncation when input length
exceeds max_seq_len. 'front','mid' and 'rear' represents the part
of input to truncate. Defaults to 'none'.
temperature (float, optional): What sampling temperature to use. temperature (float, optional): What sampling temperature to use.
If not None, will override the temperature in the `generate()` If not None, will override the temperature in the `generate()`
call. Defaults to None. call. Defaults to None.
...@@ -58,6 +63,7 @@ class OpenAI(BaseAPIModel): ...@@ -58,6 +63,7 @@ class OpenAI(BaseAPIModel):
org: Optional[Union[str, List[str]]] = None, org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = None, meta_template: Optional[Dict] = None,
openai_api_base: str = OPENAI_API_BASE, openai_api_base: str = OPENAI_API_BASE,
mode: str = 'none',
temperature: Optional[float] = None): temperature: Optional[float] = None):
super().__init__(path=path, super().__init__(path=path,
...@@ -68,6 +74,8 @@ class OpenAI(BaseAPIModel): ...@@ -68,6 +74,8 @@ class OpenAI(BaseAPIModel):
import tiktoken import tiktoken
self.tiktoken = tiktoken self.tiktoken = tiktoken
self.temperature = temperature self.temperature = temperature
assert mode in ['none', 'front', 'mid', 'rear']
self.mode = mode
if isinstance(key, str): if isinstance(key, str):
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key] self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
...@@ -137,6 +145,20 @@ class OpenAI(BaseAPIModel): ...@@ -137,6 +145,20 @@ class OpenAI(BaseAPIModel):
""" """
assert isinstance(input, (str, PromptList)) assert isinstance(input, (str, PromptList))
# max num token for gpt-3.5-turbo is 4097
context_window = 4096
if '32k' in self.path:
context_window = 32768
elif '16k' in self.path:
context_window = 16384
elif 'gpt-4' in self.path:
context_window = 8192
# will leave 100 tokens as prompt buffer, triggered if input is str
if isinstance(input, str) and self.mode != 'none':
context_window = self.max_seq_len
input = self.bin_trim(input, context_window - 100 - max_out_len)
if isinstance(input, str): if isinstance(input, str):
messages = [{'role': 'user', 'content': input}] messages = [{'role': 'user', 'content': input}]
else: else:
...@@ -151,15 +173,6 @@ class OpenAI(BaseAPIModel): ...@@ -151,15 +173,6 @@ class OpenAI(BaseAPIModel):
msg['role'] = 'system' msg['role'] = 'system'
messages.append(msg) messages.append(msg)
# max num token for gpt-3.5-turbo is 4097
context_window = 4096
if '32k' in self.path:
context_window = 32768
elif '16k' in self.path:
context_window = 16384
elif 'gpt-4' in self.path:
context_window = 8192
# Hold out 100 tokens due to potential errors in tiktoken calculation # Hold out 100 tokens due to potential errors in tiktoken calculation
max_out_len = min( max_out_len = min(
max_out_len, context_window - self.get_token_len(str(input)) - 100) max_out_len, context_window - self.get_token_len(str(input)) - 100)
...@@ -251,3 +264,45 @@ class OpenAI(BaseAPIModel): ...@@ -251,3 +264,45 @@ class OpenAI(BaseAPIModel):
""" """
enc = self.tiktoken.encoding_for_model(self.path) enc = self.tiktoken.encoding_for_model(self.path)
return len(enc.encode(prompt)) return len(enc.encode(prompt))
def bin_trim(self, prompt: str, num_token: int) -> str:
    """Trim ``prompt`` so its token length is at most ``num_token``.

    Binary-searches over the number of kept words for the largest
    candidate whose token length (per ``self.get_token_len``) fits the
    budget. ``self.mode`` selects which part is dropped: 'front' keeps
    the tail, 'rear' keeps the head, 'mid' keeps both ends.

    Args:
        prompt (str): Input string.
        num_token (int): The upper bound of token numbers.

    Returns:
        str: The trimmed prompt (unchanged if already within budget).
    """
    token_len = self.get_token_len(prompt)
    if token_len <= num_token:
        return prompt
    # CJK text has no space-delimited words, so segment with jieba;
    # otherwise split on single spaces.
    pattern = re.compile(r'[\u4e00-\u9fa5]')
    if pattern.search(prompt):
        words = list(jieba.cut(prompt, cut_all=False))
    else:
        words = prompt.split(' ')

    def _candidate(keep: int) -> str:
        # One place for the mode dispatch (the original repeated it
        # three times).
        if self.mode == 'front':
            return ' '.join(words[-keep:])
        if self.mode == 'mid':
            # BUGFIX: the two halves were previously concatenated with
            # no separator, fusing the last head word and the first
            # tail word into a single token.
            return ' '.join(words[:keep]) + ' ' + ' '.join(words[-keep:])
        return ' '.join(words[:keep])  # 'rear'

    l, r = 1, len(words)  # noqa: E741
    while l + 2 < r:
        mid = (l + r) // 2
        if self.get_token_len(_candidate(mid)) <= num_token:
            l = mid  # noqa: E741
        else:
            r = mid
    return _candidate(l)
...@@ -7,6 +7,7 @@ datasets>=2.12.0 ...@@ -7,6 +7,7 @@ datasets>=2.12.0
evaluate>=0.3.0 evaluate>=0.3.0
fairscale fairscale
faiss_gpu==1.7.2 faiss_gpu==1.7.2
fuzzywuzzy
jieba jieba
mmengine>=0.8.2 mmengine>=0.8.2
nltk==3.8 nltk==3.8
...@@ -16,6 +17,7 @@ pandas<2.0.0 ...@@ -16,6 +17,7 @@ pandas<2.0.0
rank_bm25==0.2.2 rank_bm25==0.2.2
rapidfuzz rapidfuzz
requests==2.31.0 requests==2.31.0
rouge
rouge_score rouge_score
scikit_learn==1.2.1 scikit_learn==1.2.1
sentence_transformers==2.2.2 sentence_transformers==2.2.2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment