Unverified Commit 655a807f authored by philipwangOvO's avatar philipwangOvO Committed by GitHub
Browse files

[Dataset] LongBench (#236)


Co-authored-by: default avatarwangchonghua <wangchonghua@pjlab.org.cn>
parent c6a34949
from datasets import Dataset, load_dataset
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class LongBenchrepobenchDataset(BaseDataset):
    """Loader for the LongBench RepoBench-P subset.

    Rebuilds the 'test' split so each record keeps only the fields
    downstream evaluation uses: ``input``, ``context`` and ``answers``.
    """

    @staticmethod
    def load(**kwargs):
        """Load and repack the dataset.

        Args:
            **kwargs: Forwarded verbatim to ``datasets.load_dataset``.

        Returns:
            The loaded dataset object with its 'test' split rebuilt.
        """
        dataset = load_dataset(**kwargs)
        split = 'test'
        # Iterate rows once: indexing dataset[split]['col'][i] inside a
        # loop re-materializes the whole column each time (O(n^2)).
        raw_data = [{
            'input': item['input'],
            'context': item['context'],
            'answers': item['answers'],
        } for item in dataset[split]]
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
from datasets import Dataset, load_dataset
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class LongBenchtrecDataset(BaseDataset):
    """Loader for the LongBench TREC classification subset.

    Rebuilds the 'test' split so each record keeps ``input``,
    ``context`` and an ``all_labels`` dict bundling the gold answers
    with the full label set (needed by the classification scorer).
    """

    @staticmethod
    def load(**kwargs):
        """Load and repack the dataset.

        Args:
            **kwargs: Forwarded verbatim to ``datasets.load_dataset``.

        Returns:
            The loaded dataset object with its 'test' split rebuilt.
        """
        dataset = load_dataset(**kwargs)
        split = 'test'
        # Iterate rows once: indexing dataset[split]['col'][i] inside a
        # loop re-materializes the whole column each time (O(n^2)).
        raw_data = [{
            'input': item['input'],
            'context': item['context'],
            'all_labels': {
                'answers': item['answers'],
                'all_classes': item['all_classes'],
            },
        } for item in dataset[split]]
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
from datasets import Dataset, load_dataset
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class LongBenchtriviaqaDataset(BaseDataset):
    """Loader for the LongBench TriviaQA subset.

    Rebuilds the 'test' split so each record keeps only ``input``,
    ``context`` and ``answers``.
    """

    @staticmethod
    def load(**kwargs):
        """Load and repack the dataset.

        Args:
            **kwargs: Forwarded verbatim to ``datasets.load_dataset``.

        Returns:
            The loaded dataset object with its 'test' split rebuilt.
        """
        dataset = load_dataset(**kwargs)
        split = 'test'
        # Iterate rows once: indexing dataset[split]['col'][i] inside a
        # loop re-materializes the whole column each time (O(n^2)).
        raw_data = [{
            'input': item['input'],
            'context': item['context'],
            'answers': item['answers'],
        } for item in dataset[split]]
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
from datasets import Dataset, load_dataset
from opencompass.registry import LOAD_DATASET
from ..base import BaseDataset
@LOAD_DATASET.register_module()
class LongBenchvcsumDataset(BaseDataset):
    """Loader for the LongBench VCSUM summarization subset.

    Rebuilds the 'test' split so each record keeps only ``context``
    and ``answers`` (VCSUM has no separate question field).
    """

    @staticmethod
    def load(**kwargs):
        """Load and repack the dataset.

        Args:
            **kwargs: Forwarded verbatim to ``datasets.load_dataset``.

        Returns:
            The loaded dataset object with its 'test' split rebuilt.
        """
        dataset = load_dataset(**kwargs)
        split = 'test'
        # Iterate rows once: indexing dataset[split]['col'][i] inside a
        # loop re-materializes the whole column each time (O(n^2)).
        raw_data = [{
            'context': item['context'],
            'answers': item['answers'],
        } for item in dataset[split]]
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Dict, List, Optional, Union
import jieba
import requests
from opencompass.registry import MODELS
......@@ -42,6 +44,9 @@ class OpenAI(BaseAPIModel):
wrapping of any meta instructions.
openai_api_base (str): The base url of OpenAI's API. Defaults to
'https://api.openai.com/v1/chat/completions'.
mode (str, optional): The method of input truncation when input length
exceeds max_seq_len. 'front','mid' and 'rear' represents the part
of input to truncate. Defaults to 'none'.
temperature (float, optional): What sampling temperature to use.
If not None, will override the temperature in the `generate()`
call. Defaults to None.
......@@ -58,6 +63,7 @@ class OpenAI(BaseAPIModel):
org: Optional[Union[str, List[str]]] = None,
meta_template: Optional[Dict] = None,
openai_api_base: str = OPENAI_API_BASE,
mode: str = 'none',
temperature: Optional[float] = None):
super().__init__(path=path,
......@@ -68,6 +74,8 @@ class OpenAI(BaseAPIModel):
import tiktoken
self.tiktoken = tiktoken
self.temperature = temperature
assert mode in ['none', 'front', 'mid', 'rear']
self.mode = mode
if isinstance(key, str):
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
......@@ -137,6 +145,20 @@ class OpenAI(BaseAPIModel):
"""
assert isinstance(input, (str, PromptList))
# max num token for gpt-3.5-turbo is 4097
context_window = 4096
if '32k' in self.path:
context_window = 32768
elif '16k' in self.path:
context_window = 16384
elif 'gpt-4' in self.path:
context_window = 8192
# will leave 100 tokens as prompt buffer, triggered if input is str
if isinstance(input, str) and self.mode != 'none':
context_window = self.max_seq_len
input = self.bin_trim(input, context_window - 100 - max_out_len)
if isinstance(input, str):
messages = [{'role': 'user', 'content': input}]
else:
......@@ -151,15 +173,6 @@ class OpenAI(BaseAPIModel):
msg['role'] = 'system'
messages.append(msg)
# max num token for gpt-3.5-turbo is 4097
context_window = 4096
if '32k' in self.path:
context_window = 32768
elif '16k' in self.path:
context_window = 16384
elif 'gpt-4' in self.path:
context_window = 8192
# Hold out 100 tokens due to potential errors in tiktoken calculation
max_out_len = min(
max_out_len, context_window - self.get_token_len(str(input)) - 100)
......@@ -251,3 +264,45 @@ class OpenAI(BaseAPIModel):
"""
enc = self.tiktoken.encoding_for_model(self.path)
return len(enc.encode(prompt))
def bin_trim(self, prompt: str, num_token: int) -> str:
    """Trim ``prompt`` to at most ``num_token`` tokens by binary search.

    The part that is discarded depends on ``self.mode``:
    'front' drops the beginning (keeps the tail), 'rear' drops the end
    (keeps the head), 'mid' drops the middle (keeps both ends).

    Args:
        prompt (str): Input string.
        num_token (int): The upper bound of token numbers.

    Returns:
        str: The trimmed prompt (unchanged if already short enough).
    """
    token_len = self.get_token_len(prompt)
    if token_len <= num_token:
        return prompt
    # Chinese text has no whitespace word boundaries, so segment with
    # jieba; otherwise split on spaces.
    pattern = re.compile(r'[\u4e00-\u9fa5]')
    if pattern.search(prompt):
        words = list(jieba.cut(prompt, cut_all=False))
    else:
        words = prompt.split(' ')

    def assemble(k: int) -> str:
        """Build the candidate prompt keeping k words per retained side."""
        if self.mode == 'front':
            return ' '.join(words[-k:])
        if self.mode == 'rear':
            return ' '.join(words[:k])
        # 'mid': join the combined list so a separator is inserted
        # between the two halves (the original concatenated the two
        # joined strings, gluing the boundary words together).
        return ' '.join(words[:k] + words[-k:])

    lo, hi = 1, len(words)
    # Invariant: assemble(lo) fits in num_token (or lo == 1).
    while lo + 2 < hi:
        mid = (lo + hi) // 2
        if self.get_token_len(assemble(mid)) <= num_token:
            lo = mid
        else:
            hi = mid
    return assemble(lo)
......@@ -7,6 +7,7 @@ datasets>=2.12.0
evaluate>=0.3.0
fairscale
faiss_gpu==1.7.2
fuzzywuzzy
jieba
mmengine>=0.8.2
nltk==3.8
......@@ -16,6 +17,7 @@ pandas<2.0.0
rank_bm25==0.2.2
rapidfuzz
requests==2.31.0
rouge
rouge_score
scikit_learn==1.2.1
sentence_transformers==2.2.2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment