utils.py 3.64 KB
Newer Older
Baber Abbasi's avatar
Baber Abbasi committed
1
2
3
import re
from typing import Union

Hojin Lee's avatar
Hojin Lee committed
4
5
6
7
import evaluate as hf_evaluate


# Load the `code_eval` metric once, at import time, and immediately run it on a
# tiny candidate that is known to fail its test (`add` implemented as a
# product).  The outcome in `results` is never inspected.
# NOTE(review): presumably this is a smoke test so that an environment that
# cannot run the metric fails at import rather than mid-evaluation -- confirm.
# Any exception raised here propagates unchanged (the previous
# `except Exception as e: raise e` wrapper did nothing else).
compute_ = hf_evaluate.load("code_eval")
test_cases = ["assert add(2, 3)==5"]
candidates = [["def add(a,b): return a*b"]]
results = compute_.compute(references=test_cases, predictions=candidates, k=[1])


Baber's avatar
Baber committed
16
17
18
19
20
21
def pass_at_k(
    references: list[str],
    predictions: list[list[str]],
    k: Union[int, list[int], None] = None,
):
    """Score ``predictions`` against ``references`` with the module-level metric.

    Args:
        references: One string per problem, passed through as ``references``.
        predictions: One list of candidate strings per problem, passed
            through as ``predictions``.
        k: A single ``k`` or a list of them. A bare int is wrapped into a
            one-element list. Must be supplied; the ``None`` default exists
            only so the argument can be passed by keyword.

    Returns:
        The first element of whatever ``compute_.compute`` returns.
        NOTE(review): for the ``code_eval`` metric this is presumably the
        pass@k score mapping -- confirm against the metric's documentation.

    Raises:
        AssertionError: If ``k`` is ``None``.
    """
    assert k is not None, "k must be provided (an int or a list of ints)"
    if isinstance(k, int):
        k = [k]
    # `compute_` is the metric object created at module import time; it is
    # only read here, so no `global` declaration is needed.
    res = compute_.compute(
        references=references,
        predictions=predictions,
        k=k,
    )
    return res[0]


Baber's avatar
Baber committed
29
30
31
32
33
34
35
36
37
def extract_python_block(text: str) -> str:
    """Return the body of the first ```python fenced block found in ``text``.

    Text that does not begin with a fence is first wrapped in a ```python
    fence, so bare code is returned as-is.  The result is prefixed with a
    ``from __future__ import annotations`` line.  When no fence tagged
    ``python`` (case-insensitive) is found, an empty string is returned.
    """
    fenced = text if text.startswith("```") else "```python\n" + text + "\n```"
    # Only fences whose language tag is 'python' are accepted.
    found = re.search(r"```python\n([\s\S]*?)\n?```", fenced, flags=re.IGNORECASE)
    if not found:
        return ""
    return "from __future__ import annotations\n" + found.group(1)


Baber Abbasi's avatar
Baber Abbasi committed
38
39
def extract_code_blocks(text: str) -> str:
    """Return the first fenced code block of a model response.

    An opening fence is prepended to ``text`` before matching, because the
    opening "```python" is supplied as the generation prefix and so the
    response itself normally contains only the closing fence.  An optional
    language tag and one newline on each side of the body are dropped.

    If ``text`` contains no fence at all, the whole text is used as the code.
    Either way the result is prefixed with a
    ``from __future__ import annotations`` line.

    Args:
        text: Raw response text.

    Returns:
        The prefix line followed by the extracted code.
    """
    ignore_annotations = "from __future__ import annotations\n"
    # Pattern to match ```...``` blocks
    pattern = r"```(?:\w+)?\n?(.*?)\n?```"
    # (+ ```) as we add the opening "```python" to the gen_prefix
    first_block = re.search(pattern, "```" + text, re.DOTALL)
    # The search above succeeds whenever `text` contains a fence, so a second
    # attempt on `text` alone could never find anything the first one missed;
    # with no fence present the text is returned unchanged.
    if first_block is None:
        return ignore_annotations + text
    return ignore_annotations + first_block.group(1)
Baber Abbasi's avatar
Baber Abbasi committed
52
53
54
55
56
57


def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
    """Run ``extract_code_blocks`` over every response, keeping the nesting.

    ``docs`` is accepted but not used.
    """
    predictions = []
    for resp in resps:
        predictions.append(list(map(extract_code_blocks, resp)))
    return predictions


Hojin Lee's avatar
Hojin Lee committed
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def list_fewshot_samples():
    """Return the three hard-coded few-shot samples (task ids 2, 3 and 4).

    Each sample is a dict with the keys ``task_id``, ``text``, ``code``,
    ``test_list`` and ``is_fewshot`` (always ``True``).  A fresh list of
    fresh dicts is built on every call.
    """
    # (task_id, text, code, test_list) for each sample, in output order.
    raw_samples = (
        (
            2,
            "Write a function to find the similar elements from the given two tuple lists.",
            "def similar_elements(test_tup1, test_tup2):\r\n  res = tuple(set(test_tup1) & set(test_tup2))\r\n  return (res) ",
            [
                "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)",
                "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)",
                "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)",
            ],
        ),
        (
            3,
            "Write a python function to identify non-prime numbers.",
            "import math\r\ndef is_not_prime(n):\r\n    result = False\r\n    for i in range(2,int(math.sqrt(n)) + 1):\r\n        if n % i == 0:\r\n            result = True\r\n    return result",
            [
                "assert is_not_prime(2) == False",
                "assert is_not_prime(10) == True",
                "assert is_not_prime(35) == True",
            ],
        ),
        (
            4,
            "Write a function to find the largest integers from a given list of numbers using heap queue algorithm.",
            "import heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n  largest_nums = hq.nlargest(n, nums)\r\n  return largest_nums",
            [
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] ",
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] ",
                "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]",
            ],
        ),
    )
    return [
        {
            "task_id": task_id,
            "text": text,
            "code": code,
            "test_list": test_list,
            "is_fewshot": True,
        }
        for task_id, text, code, test_list in raw_samples
    ]