Commit 9d37bcff authored by Baber's avatar Baber
Browse files

fix mbpp

parent 3e8135ce
......@@ -4,25 +4,39 @@ dataset_name: full
unsafe_code: true
output_type: generate_until
test_split: test
doc_to_text: "You are an expert Python programmer, and here is your task:\n{{text}}\nYour code should pass these tests:\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}"
doc_to_target: "{% if is_fewshot is defined %}{{code}}\n```{% else %}{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}{% endif %}"
doc_to_text: !function utils.doc_to_text
doc_to_target: "{{test_list[1]}}\n{{test_list[2]}}"
gen_prefix: "\n```python\n"
target_delimiter: ""
metric_list:
- metric: !function utils.pass_at_k
- metric: !function utils.pass_at_10
aggregation: mean
higher_is_better: true
k: [ 1 ]
k: [ 10 ]
filter_list:
- name: "extract_code"
filter:
- function: "custom"
filter_fn: !function utils.build_predictions
repeats: 20
generation_kwargs:
max_gen_toks: 256
until: [ ]
do_sample: false
num_fewshot: 3
until: [
"\nclass",
"\nassert",
'\n"""',
"\nprint",
"\nif",
"\n```",
"\n#",
"\n<|/",
"<|eot_id|>",
]
do_sample: true
temperature: 0.8
top_p: 0.95
num_fewshot: 0
fewshot_config:
sampler: first_n
samples: !function utils.list_fewshot_samples
......
......@@ -13,17 +13,26 @@ except Exception as e:
raise e
def pass_at_k(references: list[str], predictions: list[list[str]], k: list[int] = None):
def pass_at_10(
references: list[str], predictions: list[list[str]], k: list[int] = None
):
global compute_
assert k is not None
if isinstance(k, int):
k = [k]
if isinstance(references, str):
references = [references]
if isinstance(predictions[0], str):
predictions = [[p] for p in predictions]
print(f"{references=}")
print(f"{predictions=}")
print(f"{k=}")
res = compute_.compute(
references=references,
predictions=predictions,
k=k,
)
return res[0]
return res[0][f"pass@{str(k[0])}"]
def extract_python_block(text: str) -> str:
......@@ -51,8 +60,20 @@ def extract_code_blocks(text: str) -> str:
return ignore_annotations + matches[0]
def doc_to_text(doc: dict) -> str:
text = (
doc["text"]
+ "\n"
+ doc["code"].split(":")[0]
+ ":"
+ "\n"
+ "Here is the completed function:\n\n```python\n"
)
return text
def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
return [[extract_code_blocks(r) for r in resp] for resp in resps]
return [[extract_python_block(r) for r in resp] for resp in resps]
def list_fewshot_samples():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment