Commit 83827fd9 authored by Baber's avatar Baber
Browse files

add mbpp

parent 2507c434
task: mbpp_evalplus
dataset_path: google-research-datasets/mbpp
dataset_name: full
unsafe_code: true
output_type: generate_until
test_split: test
repeats: 20
#doc_to_text: "You are an expert Python programmer, and here is your task: {{text}} Your code should pass these tests:\n\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}\n[BEGIN]\n"
doc_to_text: |
Please provide a self-contained Python script that solves the following problem in a markdown code block:
```
{{text|trim}}
{{test_list|random}}
```
doc_to_target: "{% if is_fewshot is defined %}{{code}}\n[DONE]{% else %}{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}{% endif %}"
target_delimiter: ""
gen_prefix: "Here is the completed function:\n\n```python\n"
metric_list:
- metric: !function utils.pass_at_10
aggregation: mean
higher_is_better: true
filter_list:
- name: "create_test"
filter:
- function: "custom"
filter_fn: !function utils.build_predictions
generation_kwargs:
until: [
"\nclass",
"\nassert",
'\n"""',
"\nprint",
"\nif",
"\n```",
"\n#",
"\n<|/",
"<|eot_id|>",
]
do_sample: true
temperature: 0.8
top_p: 0.95
max_gen_toks: 2
num_fewshot: 0
fewshot_config:
sampler: first_n
samples: !function utils.list_fewshot_samples
metadata:
version: 1.0
...@@ -29,6 +29,22 @@ def pass_at_1( ...@@ -29,6 +29,22 @@ def pass_at_1(
)[0]["pass@1"] )[0]["pass@1"]
def pass_at_10(
references: Union[str, list[str]], predictions: Union[str, list[list[str]]]
) -> float:
if isinstance(references, str):
references = [references]
if isinstance(predictions[0], str):
predictions = [[p] for p in predictions]
print("references: ", references)
print("predictions: ", predictions)
pass_at_k = hf_evaluate.load("code_eval")
res = pass_at_k.compute(
references=references, predictions=predictions, k=[10], num_workers=20
)
return res[0]
def extract_code_blocks(text: str) -> str: def extract_code_blocks(text: str) -> str:
# Pattern to match ```...``` blocks # Pattern to match ```...``` blocks
pattern = r"```(?:\w+)?\n?(.*?)\n?```" pattern = r"```(?:\w+)?\n?(.*?)\n?```"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment