Unverified Commit 60e85da5 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

add Mbpp instruct (#2995)

* feat: add mbpp_instruct

* fix: update generation_kwargs to use an empty until list

* fix: correct predictions formatting in pass_at_1 function

* fix: improve code block extraction by checking first without opening backticks

* fix mbpp `pass_at_1`
parent d57e3d65
task: mbpp_instruct
dataset_path: google-research-datasets/mbpp
dataset_name: full
unsafe_code: true
output_type: generate_until
test_split: test
doc_to_text: "You are an expert Python programmer, and here is your task:\n{{text}}\nYour code should pass these tests:\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}"
doc_to_target: "{% if is_fewshot is defined %}{{code}}\n```{% else %}{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}{% endif %}"
gen_prefix: "\n```python\n"
target_delimiter: ""
metric_list:
- metric: !function utils.pass_at_1
aggregation: mean
higher_is_better: true
filter_list:
- name: "extract_code"
filter:
- function: "custom"
filter_fn: !function utils.build_predictions
generation_kwargs:
max_gen_toks: 256
until: []
do_sample: false
num_fewshot: 3
fewshot_config:
sampler: first_n
samples: !function utils.list_fewshot_samples
metadata:
version: 1.0
include: mbpp_instruct.yaml
task: mbpp_plus_instruct
dataset_path: evalplus/mbppplus
dataset_name: null
doc_to_text: "{{prompt if prompt is defined else text}} Your code should satisfy the following assertion:\n{{test_list[0]}}"
doc_to_target: "{{test_list[0]}}"
gen_prefix: "Here is a solution to this programming problem:\n```python\n"
num_fewshot: 0
generation_kwargs:
max_gen_toks: 1024
until: []
do_sample: false
import re
from typing import Union
import evaluate as hf_evaluate import evaluate as hf_evaluate
...@@ -12,14 +15,41 @@ except Exception as e: ...@@ -12,14 +15,41 @@ except Exception as e:
raise e raise e
def pass_at_1(references, predictions): def pass_at_1(
references: Union[str, list[str]], predictions: Union[str, list[list[str]]]
) -> float:
if isinstance(references, str):
references = [references]
if isinstance(predictions[0], str):
predictions = [[p] for p in predictions]
print(f"References: {references}")
print(f"Predictions: {predictions}")
return pass_at_k.compute( return pass_at_k.compute(
references=references, references=references,
predictions=[predictions], predictions=predictions,
k=[1], k=[1],
)[0]["pass@1"] )[0]["pass@1"]
def extract_code_blocks(text: str) -> str:
# Pattern to match ```...``` blocks
pattern = r"```(?:\w+)?\n?(.*?)\n?```"
# (+ ```) as we add the opening "```python" to the gen_prefix
matches = re.findall(pattern, r"```" + text, re.DOTALL)
# if no matches, try to match ```...``` blocks (after removing the language)
if not matches:
text_without_lang = re.sub(r"```python", "```", text)
matches = re.findall(pattern, text_without_lang, re.DOTALL)
if not matches:
return ""
else:
return matches[0]
def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
return [[extract_code_blocks(r) for r in resp] for resp in resps]
def list_fewshot_samples(): def list_fewshot_samples():
return [ return [
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment