utils.py 4.13 KB
Newer Older
Baber Abbasi's avatar
Baber Abbasi committed
1
2
3
import re
from typing import Union

Hojin Lee's avatar
Hojin Lee committed
4
5
6
7
import evaluate as hf_evaluate


# Load the HuggingFace ``code_eval`` metric once at import time. ``pass_at_10``
# reads this module-level ``compute_`` handle when scoring.
compute_ = hf_evaluate.load("code_eval")
# Exercise the metric once with a candidate that fails its test (2 * 3 != 5),
# presumably so a misconfigured code-execution setup errors out at import
# rather than midway through scoring. The result is not inspected.
# NOTE(review): ``code_eval`` executes candidate code; some ``evaluate``
# versions refuse to run it unless HF_ALLOW_CODE_EVAL=1 is set — confirm for
# the pinned version. Any exception propagates unchanged (the former
# ``except Exception as e: raise e`` wrapper was a no-op).
test_cases = ["assert add(2, 3)==5"]
candidates = [["def add(a,b): return a*b"]]
results = compute_.compute(references=test_cases, predictions=candidates, k=[1])


Baber's avatar
Baber committed
16
17
18
def pass_at_10(
    references: list[str],
    predictions: list[list[str]],
    k: Union[int, list[int], None] = None,
):
    """Score candidate programs with the ``code_eval`` metric and return pass@k.

    Args:
        references: Test code per problem. A bare string is treated as a
            single problem.
        predictions: Candidate programs per problem (a list of candidates for
            each problem). A flat list of strings is treated as one candidate
            per problem.
        k: The ``k`` value(s) forwarded to ``code_eval``. Despite the function
            name, nothing here fixes it to 10; the caller must supply it.

    Returns:
        The ``pass@{k[0]}`` entry of the metric result. When several ``k``
        values are given, only the first one is reported.

    Raises:
        ValueError: if ``k`` is missing or an empty list.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if k is None:
        raise ValueError("pass_at_10 requires `k` (an int or a list of ints).")
    if isinstance(k, int):
        k = [k]
    if not k:
        raise ValueError("pass_at_10 requires at least one value of `k`.")

    if isinstance(references, str):
        references = [references]
    # Normalise a flat list of strings to one candidate per problem; guard the
    # index so empty predictions don't raise IndexError here.
    if predictions and isinstance(predictions[0], str):
        predictions = [[p] for p in predictions]

    # ``compute_`` is only read here, so no ``global`` declaration is needed.
    res = compute_.compute(
        references=references,
        predictions=predictions,
        k=k,
    )
    return res[0][f"pass@{k[0]}"]
Baber's avatar
Baber committed
36
37


Baber's avatar
Baber committed
38
39
40
41
42
43
44
45
46
def extract_python_block(text: str) -> str:
    """Extract the body of the first python-tagged fenced block in a response.

    The prompt built by ``doc_to_text`` ends with an opening ```` ```python ````
    fence, so a response is expected to be bare code followed by a closing
    fence. Such a response is re-wrapped in a python fence before matching. A
    response that opens its own fence (after any leading whitespace) is
    searched as-is.

    Returns:
        The extracted code prefixed with ``from __future__ import annotations``
        (so annotations in the generated code are not evaluated), or ``""``
        when no python-tagged block is found.
    """
    # Ignore leading whitespace when deciding whether the response already
    # opens a fence. Otherwise "\n```python\n..." would be wrapped again, the
    # wrapper's opening fence would pair with the response's own fence, and
    # the captured code would be empty.
    if not text.lstrip().startswith("```"):
        text = "```python\n" + text + "\n```"
    # capture only fences whose language tag is 'python'
    pattern = re.compile(r"```python\n([\s\S]*?)\n?```", re.IGNORECASE)
    m = pattern.search(text)
    return "from __future__ import annotations\n" + m.group(1) if m else ""


Baber Abbasi's avatar
Baber Abbasi committed
47
48
def extract_code_blocks(text: str) -> str:
    """Return the first fenced code block of a response, or the whole response.

    The opening ```` ``` ```` is assumed to belong to the generation prefix,
    so it is re-attached before matching. An optional language tag (for
    example ``python``) is dropped, but only when it ends the fence line.
    This keeps the first token of a response that starts directly with code,
    such as ``def f(...):``, which a bare ``\\w+`` tag pattern would consume.

    Returns:
        The extracted code, or the entire ``text`` when no closing fence is
        found, prefixed with ``from __future__ import annotations`` so that
        annotations in the generated code are not evaluated.
    """
    ignore_annotations = "from __future__ import annotations\n"
    # (+ ```) as we add the opening "```python" to the gen_prefix
    pattern = r"```(?:\w*[ \t]*\r?\n)?(.*?)\n?```"
    match = re.search(pattern, "```" + text, re.DOTALL)
    # With "```" prepended, the search fails only when ``text`` contains no
    # "```" at all. Stripping a "```python" tag could never create a match,
    # so there is no second search.
    if match is None:
        return ignore_annotations + text
    return ignore_annotations + match.group(1)
Baber Abbasi's avatar
Baber Abbasi committed
61
62


Baber's avatar
Baber committed
63
64
65
66
67
68
69
70
71
72
73
74
def doc_to_text(doc: dict) -> str:
    """Render a task document as a code-completion prompt.

    The prompt consists of:

    1. the task description;
    2. everything in ``doc["code"]`` up to its first ``:``, re-terminated
       with ``:``;
    3. an instruction line;
    4. an opened ```` ```python ```` fence for the model to continue.

    NOTE(review): splitting on the first colon assumes no ``:`` occurs before
    the end of the first ``def`` header (annotations or default values
    containing ``:`` would truncate it). Any lines before the header, such as
    imports, are kept.
    """
    header = doc["code"].split(":", 1)[0]
    pieces = [
        doc["text"],
        "\n",
        header,
        ":",
        "\n",
        "Here is the completed function:\n\n```python\n",
    ]
    return "".join(pieces)


Baber Abbasi's avatar
Baber Abbasi committed
75
def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
    """Convert raw model responses into candidate programs.

    Args:
        resps: For each document, the list of sampled responses.
        docs: The matching documents. They are not read here; presumably the
            parameter exists to satisfy the caller's hook signature (confirm).

    Returns:
        A list parallel to ``resps`` in which every response has been passed
        through ``extract_python_block``.
    """
    predictions: list[list[str]] = []
    for samples in resps:
        predictions.append(list(map(extract_python_block, samples)))
    return predictions
Baber Abbasi's avatar
Baber Abbasi committed
77
78


Hojin Lee's avatar
Hojin Lee committed
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def list_fewshot_samples():
    """Return the three fixed few-shot exemplars (task ids 2, 3 and 4).

    Each exemplar has ``task_id``, ``text`` and ``code`` (the fields
    ``doc_to_text`` reads are ``text`` and ``code``), a ``test_list`` of
    assertion strings, and ``is_fewshot=True``. A new list of new dicts is
    built on every call, so callers may mutate the result freely.
    """
    exemplars = [
        dict(
            task_id=2,
            text="Write a function to find the similar elements from the given two tuple lists.",
            code="def similar_elements(test_tup1, test_tup2):\r\n  res = tuple(set(test_tup1) & set(test_tup2))\r\n  return (res) ",
            test_list=[
                "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)",
                "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)",
                "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)",
            ],
        ),
        dict(
            task_id=3,
            text="Write a python function to identify non-prime numbers.",
            code="import math\r\ndef is_not_prime(n):\r\n    result = False\r\n    for i in range(2,int(math.sqrt(n)) + 1):\r\n        if n % i == 0:\r\n            result = True\r\n    return result",
            test_list=[
                "assert is_not_prime(2) == False",
                "assert is_not_prime(10) == True",
                "assert is_not_prime(35) == True",
            ],
        ),
        dict(
            task_id=4,
            text="Write a function to find the largest integers from a given list of numbers using heap queue algorithm.",
            code="import heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n  largest_nums = hq.nlargest(n, nums)\r\n  return largest_nums",
            test_list=[
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] ",
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] ",
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]",
            ],
        ),
    ]
    # Every exemplar carries the same marker, so it is set in one place.
    for exemplar in exemplars:
        exemplar["is_fewshot"] = True
    return exemplars