utils.py 4.05 KB
Newer Older
Baber Abbasi's avatar
Baber Abbasi committed
1
2
3
import re
from typing import Union

Hojin Lee's avatar
Hojin Lee committed
4
5
6
7
import evaluate as hf_evaluate


# Load the metric once at import time and immediately score one candidate that
# fails its test (2 * 3 != 5), so any problem with loading or running the metric
# surfaces here rather than in the middle of an evaluation run. Errors propagate
# unchanged; the previous `except Exception as e: raise e` wrapper added nothing.
# NOTE(review): presumably this also triggers code_eval's opt-in check for
# executing untrusted code (HF_ALLOW_CODE_EVAL) early -- confirm.
compute_ = hf_evaluate.load("code_eval")
test_cases = ["assert add(2, 3)==5"]
candidates = [["def add(a,b): return a*b"]]
results = compute_.compute(references=test_cases, predictions=candidates, k=[1])


Baber's avatar
Baber committed
16
17
18
def pass_at_10(
    references: list[str],
    predictions: list[list[str]],
    k: Union[int, list[int], None] = None,
):
    """Score candidate programs with the module-level ``compute_`` metric.

    Despite the name, the value returned is ``pass@k`` for the first entry of
    ``k``, whatever that entry is.

    Args:
        references: Test code, one entry per problem. A bare string is
            treated as a single problem.
        predictions: Candidate programs, one list per problem. A flat list
            of strings is treated as one candidate per problem.
        k: The ``k`` value(s) passed to the metric, as an int or a list of
            ints. Required, even though the signature defaults to ``None``.

    Returns:
        The ``pass@<k[0]>`` entry from the first element of the metric result.

    Raises:
        ValueError: If ``k`` is not provided.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if k is None:
        raise ValueError("pass_at_10 requires `k` (an int or a list of ints)")
    ks = [k] if isinstance(k, int) else k

    if isinstance(references, str):
        references = [references]
    # A flat list of strings means one candidate per problem; nest each one.
    if isinstance(predictions[0], str):
        predictions = [[candidate] for candidate in predictions]

    scores = compute_.compute(
        references=references,
        predictions=predictions,
        k=ks,
    )
    # Only the score for the first requested k is reported.
    return scores[0][f"pass@{ks[0]}"]
Baber's avatar
Baber committed
33
34


Baber's avatar
Baber committed
35
36
37
38
39
40
41
42
43
def extract_python_block(text: str) -> str:
    """Return the first ```python fenced block of ``text`` as runnable source.

    Text that does not begin with a fence is first wrapped in a ```python
    fence, so a bare continuation (optionally followed by a closing fence and
    trailing chatter) is handled the same way as a fully fenced answer.

    Returns:
        ``"from __future__ import annotations\\n"`` followed by the block
        contents, or ``""`` when no ```python fence is found.
    """
    if text.startswith("```"):
        fenced = text
    else:
        fenced = "".join(("```python\n", text, "\n```"))

    # Only fences tagged 'python' (any letter case) are accepted; the lazy
    # capture stops at the first closing fence.
    found = re.search(r"```python\n([\s\S]*?)\n?```", fenced, flags=re.IGNORECASE)
    if found is None:
        return ""
    return "from __future__ import annotations\n" + found.group(1)


Baber Abbasi's avatar
Baber Abbasi committed
44
45
def extract_code_blocks(text: str) -> str:
    """Extract the first fenced code block from a model response.

    The response is expected to continue a code fence that the prompt already
    opened, so an opening fence is prepended before matching. Three shapes are
    handled:

    * the response re-opens its own fence (starts with three backticks): the
      first complete block in the raw text wins;
    * the response starts with a newline: a bare fence is prepended;
    * the response starts directly with code: a fence plus a newline is
      prepended, so the optional language-tag group cannot swallow the first
      word of the code (e.g. ``def``) as if it were a language name.

    Returns:
        ``"from __future__ import annotations\\n"`` followed by the extracted
        block, or by the whole of ``text`` when no closed block is found.
    """
    ignore_annotations = "from __future__ import annotations\n"
    # Opening fence, optional language tag, optional newline, lazy body,
    # optional newline, closing fence.
    pattern = re.compile(r"```(?:\w+)?\n?(.*?)\n?```", re.DOTALL)

    if text.startswith("```"):
        found = pattern.search(text)
        if found is not None:
            return ignore_annotations + found.group(1)
        # No closed block: treat the leading fence as closing the prompt's fence.
        opener = "```"
    elif text.startswith("\n"):
        opener = "```"
    else:
        opener = "```\n"

    found = pattern.search(opener + text)
    if found is None:
        return ignore_annotations + text
    return ignore_annotations + found.group(1)
Baber Abbasi's avatar
Baber Abbasi committed
58
59


Baber's avatar
Baber committed
60
61
62
63
64
65
66
67
68
69
70
71
def doc_to_text(doc: dict) -> str:
    """Build the generation prompt for one document.

    The prompt is, in order: ``doc["text"]``, the part of ``doc["code"]`` that
    precedes its first colon with a colon re-appended, the sentence
    "Here is the completed function:", and an opening ```python fence for the
    model to continue.
    """
    # Everything before the first ':'; the whole string if there is no colon.
    signature, _, _ = doc["code"].partition(":")
    pieces = (
        doc["text"],
        "\n",
        signature,
        ":",
        "\n",
        "Here is the completed function:\n\n```python\n",
    )
    return "".join(pieces)


Baber Abbasi's avatar
Baber Abbasi committed
72
def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
    """Turn raw model responses into code strings.

    Args:
        resps: Lists of response strings; the nesting is preserved in the
            result.
        docs: Accepted but not used here.
            NOTE(review): presumably required by the caller's filter
            signature -- confirm.

    Returns:
        A list shaped like ``resps`` in which every response has been passed
        through ``extract_python_block``.
    """
    return [[extract_python_block(r) for r in resp] for resp in resps]
Baber Abbasi's avatar
Baber Abbasi committed
74
75


Hojin Lee's avatar
Hojin Lee committed
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
def list_fewshot_samples():
    """Return the fixed few-shot examples (task ids 2, 3 and 4).

    Each example is a new dict with the keys ``task_id``, ``text``, ``code``,
    ``test_list`` and ``is_fewshot`` (always ``True``), in that order. A fresh
    list is built on every call.
    """
    # One record per example: (task_id, text, code, test_list).
    records = (
        (
            2,
            "Write a function to find the similar elements from the given two tuple lists.",
            "def similar_elements(test_tup1, test_tup2):\r\n  res = tuple(set(test_tup1) & set(test_tup2))\r\n  return (res) ",
            [
                "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)",
                "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)",
                "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)",
            ],
        ),
        (
            3,
            "Write a python function to identify non-prime numbers.",
            "import math\r\ndef is_not_prime(n):\r\n    result = False\r\n    for i in range(2,int(math.sqrt(n)) + 1):\r\n        if n % i == 0:\r\n            result = True\r\n    return result",
            [
                "assert is_not_prime(2) == False",
                "assert is_not_prime(10) == True",
                "assert is_not_prime(35) == True",
            ],
        ),
        (
            4,
            "Write a function to find the largest integers from a given list of numbers using heap queue algorithm.",
            "import heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n  largest_nums = hq.nlargest(n, nums)\r\n  return largest_nums",
            [
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] ",
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] ",
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]",
            ],
        ),
    )
    return [
        {
            "task_id": task_id,
            "text": prompt,
            "code": solution,
            "test_list": tests,
            "is_fewshot": True,
        }
        for task_id, prompt, solution, tests in records
    ]