Unverified Commit 444d8d95 authored by Connor-Shen's avatar Connor-Shen Committed by GitHub
Browse files

[feat] support multipl-e (#846)



* [feat] support humaneval_multipl-e

* format

---------
Co-authored-by: default avatarLeymore <zfz-960727@163.com>
parent a6c49f15
from mmengine.config import read_base
with read_base():
from .humaneval_multi_gen_82cf85 import humaneval_multi_datasets # noqa: F401, F403
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HumanevalMultiDataset, HumanevalMultiEvaluator
humaneval_multi_reader_cfg = dict(input_columns=['prompt'], output_column='tests')
humaneval_multi_infer_cfg = dict(
prompt_template=dict(type=PromptTemplate, template='{prompt}'),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer, max_out_len=1024),
)
humaneval_multi_eval_cfg = {
lang: dict(
evaluator=dict(
type=HumanevalMultiEvaluator,
language=lang,
ip_address='localhost', # replace to your code_eval_server ip_address, port
port=5000,
), # refer to https://opencompass.readthedocs.io/en/latest/advanced_guides/code_eval_service.html to launch a server
pred_role='BOT',
) for lang in ['cpp', 'cs', 'd', 'go', 'java', 'jl', 'js', 'lua', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh', 'swift', 'ts']
}
'''there are four versions of humaneval-{LANG}-{version}.jsonl:
['keep', 'transform', 'reworded', 'remove']
SRCDATA-LANG-keep is the same as SRCDATA-LANG, but the text of the prompt is totally unchanged. If the original prompt had Python doctests, they remain as Python instead of being translated to LANG. If the original prompt had Python-specific terminology, e.g., 'list', it remains 'list', instead of being translated, e.g., to 'vector' for C++.
SRCDATA-LANG-transform transforms the doctests to LANG but leaves the natural language text of the prompt unchanged.
SRCDATA-LANG-reworded transforms both the doctests and the natural language text of the prompt to LANG.
SRCDATA-LANG-remove removes the doctests from the prompt.
'''
humaneval_multi_datasets = [
dict(
type=HumanevalMultiDataset,
abbr=f'humaneval_multiple-{lang}',
language=lang,
version='reworded', # choose from ['keep', 'transform', 'reworded', 'remove']
num_repeats=1,
path='./data/multi-data/humaneval_multipl-e/',
reader_cfg=humaneval_multi_reader_cfg,
infer_cfg=humaneval_multi_infer_cfg,
eval_cfg=humaneval_multi_eval_cfg[lang],
) for lang in ['cpp', 'cs', 'd', 'go', 'java', 'jl', 'js', 'lua', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh', 'swift', 'ts']
]
import gzip
import json
import os
import os.path as osp
import re
import shutil
import subprocess
import tempfile
import time
import numpy as np
from datasets import Dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
# currently supporting 19 languages
_LANGUAGE_NAME_DICT = {
'cpp': 'CPP',
'cs': 'C#',
'd': 'D',
'go': 'Go',
'java': 'Java',
'jl': 'Julia',
'js': 'JavaScript',
'lua': 'Lua',
'php': 'PHP',
'pl': 'Perl',
'py': 'Python',
'r': 'R',
'rb': 'Ruby',
'rkt': 'Racket',
'rs': 'Rust',
'scala': 'Scala',
'sh': 'Shell',
'swift': 'Swift',
'ts': 'TypeScript',
}
@LOAD_DATASET.register_module()
class HumanevalMultiDataset(BaseDataset):
@staticmethod
def load(path, language, version, num_repeats: int = 1, **kwargs):
"""Load humaneval dataset for pass k mode.
Note that you can use num_repeats > 1 when your model does not support
`num_return_sequence` in generation, otherwise use the raw
humaneval dataset and set `num_return_sequence` in model config to
generate multiple responses for testing pass@k>1.
It better to change your dataset abbr correspondingly if you want to
change num_repeats>1, otherwise the number in
`.cache/dataset_size.json` might be inconsistent.
Args:
num_repeats(int): Number of repetition for this dataset to get
multiple responses in special cases.
"""
assert language in _LANGUAGE_NAME_DICT.keys(), (
f'language must be in {list(_LANGUAGE_NAME_DICT.keys())}')
assert version in [
'keep', 'transform', 'reworded', 'remove'
], ('version must be in ["keep", "transform", "reworded", "remove"]')
file_path = osp.join(path, f'humaneval-{language}-{version}.jsonl')
dataset = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
dataset.extend(
[json.loads(line.strip()) for _ in range(num_repeats)])
return Dataset.from_list(dataset)
class HumanevalMultiEvaluator(BaseEvaluator):
def __init__(self,
language,
ip_address='localhost',
port=5000,
retry=2,
timeout=600) -> None:
self.language = language
self.ip_address = ip_address
self.port = port
self.retry = retry
self.timeout = timeout
super().__init__()
def stop_at_stop_token(self, decoded_string, stop_tokens):
"""Produces the prefix of decoded_string that ends at the first
occurrence of a stop_token.
WARNING: the decoded_string *must not* include the prompt,
which may have stop tokens itself.
"""
min_stop_index = len(decoded_string)
for stop_token in stop_tokens:
stop_index = decoded_string.find(stop_token)
if stop_index != -1 and stop_index < min_stop_index:
min_stop_index = stop_index
return decoded_string[:min_stop_index]
def _code_eval_service(self, file_path):
exec_result = subprocess.run([
'curl', '-X', 'POST', '-F', f'file=@{file_path}', '-F',
f'dataset=multipl-e/{self.language}',
f'{self.ip_address}:{self.port}/evaluate'
],
timeout=self.timeout,
capture_output=True)
if exec_result.returncode == 0 and re.match(
"\"{.*:.*}\"", exec_result.stdout.decode('utf-8')):
return True, json.loads(exec_result.stdout.decode('utf-8'))
else:
if exec_result.stderr:
try:
err = exec_result.stderr.decode()
except Exception:
err = exec_result.stderr
else:
try:
err = exec_result.stdout.decode()
except Exception:
err = exec_result.stdout
return False, err
def estimator(self, n: int, c: int, k: int) -> float:
"""
Calculates 1 - comb(n - c, k) / comb(n, k).
"""
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
def for_file(self, path):
try:
with gzip.open(path, 'rt') as f:
data = json.load(f)
except Exception:
return None
n = len(data['results'])
c = len([
True for r in data['results']
if r['status'] == 'OK' and r['exit_code'] == 0
])
return {
'pass@1': self.estimator(n, c, 1),
'pass@10': self.estimator(n, c, 10),
'pass@100': self.estimator(n, c, 100),
'n': n,
'c': c,
}
def score(self, predictions, references, test_set):
stop_tokens = test_set['stop_tokens'][0]
print(stop_tokens)
# convert to original version
test_set = test_set.to_pandas()
test_set_origin = test_set.drop_duplicates(subset='name')
num_repeats = int(len(test_set) / len(test_set_origin))
print(num_repeats)
# Create a temporary directory using the tempfile module
with tempfile.TemporaryDirectory() as tmpdir:
for i in range(len(test_set_origin)):
completions = predictions[i * num_repeats:(i + 1) *
num_repeats]
processed_completions = []
for comp in completions:
comp = self.stop_at_stop_token(comp, stop_tokens)
processed_completions.append(comp)
result_dict = {
'name': test_set_origin.iloc[i]['name'],
'language': test_set_origin.iloc[i]['language'],
'prompt': test_set_origin.iloc[i]['prompt'],
'tests': test_set_origin.iloc[i]['tests'],
'completions': processed_completions
}
json_str = json.dumps(result_dict)
json_bytes = json_str.encode('utf-8')
with gzip.GzipFile(
os.path.join(tmpdir, f'{result_dict["name"]}.json.gz'),
'w') as f:
f.write(json_bytes)
# create a zip file containing all the generated .json.gz files
zipname = os.path.join(tmpdir, 'archive')
shutil.make_archive(zipname, 'zip', tmpdir)
zipfile_path = f'{zipname}.zip'
num_retry = 0
while num_retry < self.retry:
succeed, output = self._code_eval_service(
file_path=zipfile_path)
if not succeed and '(56) Recv failure' in output:
# only retry when connection failed
num_retry += 1
# wait a min in case the service load is too high
time.sleep(60)
else:
break
if succeed:
if isinstance(output, str):
return json.loads(output)
elif isinstance(output, dict):
return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment