ggml.py 4.5 KB
Newer Older
Matt Hoffner's avatar
Matt Hoffner committed
1
import requests
2
import logging
Matt Hoffner's avatar
Matt Hoffner committed
3
import time
Matt Hoffner's avatar
Matt Hoffner committed
4
5
6
from tqdm import tqdm
from requests.exceptions import RequestException

Matt Hoffner's avatar
Matt Hoffner committed
7
8
9
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM

10
11
logger = logging.getLogger(__name__)

Matt Hoffner's avatar
Matt Hoffner committed
12
13
14
15
16
17
18
19
20
21
22
def ggml_completion(base_url, retries=3, delay=5, **kwargs):
    """POST ``kwargs`` as JSON to ``{base_url}/v1/completions``, retrying on failure.

    Args:
        base_url: Root URL of the completion server.
        retries: Maximum number of attempts before giving up.
        delay: Seconds to wait between two consecutive attempts.
        **kwargs: Sent verbatim as the JSON body of the request.

    Returns:
        The decoded JSON body of the first successful response.

    Raises:
        Exception: If every attempt fails with a ``RequestException``. The
            last underlying error is attached as ``__cause__``.
    """
    # Python unbinds the ``except ... as`` name when the except clause ends,
    # so the last error is kept in a separate local for the final message.
    last_error = None
    for attempt in range(1, retries + 1):
        try:
            response = requests.post(f"{base_url}/v1/completions", json=kwargs)
            response.raise_for_status()
            return response.json()
        except RequestException as e:
            last_error = e
            logger.error(f"RequestException: {e}")
            # Only wait when another attempt will follow.
            if attempt < retries:
                time.sleep(delay)
    raise Exception(
        f"Failed to get a valid response after {retries} retries. "
        f"Last exception: {last_error}"
    ) from last_error
Matt Hoffner's avatar
Matt Hoffner committed
23

Matt Hoffner's avatar
Matt Hoffner committed
24
class GGMLLM(BaseLM):
    """Language-model adapter backed by an HTTP completion server.

    Each request is forwarded, one at a time, to ``{base_url}/v1/completions``
    through ``ggml_completion``. The remaining methods and properties are
    placeholders that raise ``NotImplementedError``, except ``max_length``.
    """

    def __init__(self, base_url, truncate=False):
        """Store the connection settings.

        Args:
            base_url: Root URL of the completion server.
            truncate: Stored on the instance; not consulted by any method
                of this class.
        """
        super().__init__()
        self.base_url = base_url
        self.truncate = truncate

    def loglikelihood(self, requests):
        """Score ``(context, continuation)`` pairs using the server.

        Args:
            requests: Sequence of ``(context, continuation)`` pairs.

        Returns:
            A list aligned with ``requests`` of ``(logprob, is_greedy)``
            tuples. ``logprob`` is the first entry of the response's
            ``logprobs["token_logprobs"]``; ``is_greedy`` is True when the
            response's ``finish_reason`` equals ``"length"``.

        Raises:
            ValueError: If the response's ``logprobs`` is missing or does not
                hold a non-empty ``token_logprobs`` list.
            AssertionError: If the response has no usable ``choices``.
        """
        # Requests are sent one at a time, so reordering them brings no
        # batching benefit. ``len`` of a (context, continuation) pair is also
        # the same for every request, so it is not a meaningful ordering key.
        # Results are therefore produced directly in the callers' order.
        res = []
        for context, continuation in tqdm(requests):
            response = ggml_completion(
                self.base_url, context=context, continuation=continuation
            )
            if not (response and "choices" in response and response["choices"]):
                message = f"Invalid response for loglikelihood. Response: {response}"
                logger.error(message)
                # Raised explicitly: a bare ``assert`` is removed under
                # ``python -O``, which would silently misalign the results.
                raise AssertionError(message)

            choice = response["choices"][0]
            logprobs = choice.get("logprobs")
            try:
                logprob = logprobs["token_logprobs"][0]
            except (TypeError, KeyError, IndexError) as err:
                # TypeError: logprobs is None; KeyError: key missing;
                # IndexError: token_logprobs is empty.
                raise ValueError(
                    "Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
                ) from err
            is_greedy = choice["finish_reason"] == "length"
            res.append((logprob, is_greedy))
        return res

    def greedy_until(self, requests):
        """Generate a completion for each request until a stop sequence.

        Args:
            requests: Sequence of items whose first element is the input text
                and whose second element is a mapping with an ``"until"``
                entry, passed to the server as ``stop``.

        Returns:
            A list aligned with ``requests``. Each entry is the stripped
            ``text`` of the first choice, or ``None`` when the response was
            unusable (the error is logged and evaluation continues).
        """
        if not requests:
            return []

        # Processed in the given order; see the note in ``loglikelihood`` --
        # ``len`` of a request tuple is not a meaningful ordering key.
        res = []
        for request in tqdm(requests):
            inp = request[0]
            request_args = request[1]
            until = request_args["until"]
            response = ggml_completion(self.base_url, context=inp, stop=until)
            logger.debug(f"greedy_until response: {response}")

            choice = None
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]

            if choice is not None and "text" in choice:
                res.append(choice["text"].strip())
            else:
                logger.error(f"Invalid response for greedy_until. Response: {response}")
                res.append(None)  # Add default value in case of error
        return res

    def loglikelihood_rolling(self, requests):
        """Return an averaged log-likelihood for each request.

        Each request is sliced into windows of ``max_length`` items. Every
        window is scored with ``loglikelihood`` against the same window
        shifted by one position, and the resulting log-probabilities are
        averaged.

        Args:
            requests: Iterable of sliceable requests.

        Returns:
            A list of ``(avg_loglikelihood, True)`` tuples, one per request.
            A request that yields no windows scores ``0.0``.
        """
        # NOTE(review): this averages per-window scores rather than summing
        # them, and slices ``request`` directly -- confirm the element type
        # callers pass (string vs. 1-tuple) and the intended aggregation.
        results = []
        window = self.max_length

        for request in requests:
            logprobs = []
            for i in range(0, len(request), window):
                chunk = request[i:i + window]
                shifted = request[i + 1:i + window + 1]
                logprobs.extend(self.loglikelihood([(chunk, shifted)]))

            if logprobs:
                avg_loglikelihood = sum(logprob for logprob, _ in logprobs) / len(logprobs)
            else:
                # Empty request: nothing was scored, avoid dividing by zero.
                avg_loglikelihood = 0.0
            results.append((avg_loglikelihood, True))

        return results

    def _model_call(self, inps):
        """Placeholder; always raises ``NotImplementedError``."""
        # Placeholder implementation
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        """Placeholder; always raises ``NotImplementedError``."""
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def batch_size(self):
        """Placeholder; always raises ``NotImplementedError``."""
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def device(self):
        """Placeholder; always raises ``NotImplementedError``."""
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def eot_token_id(self):
        """Placeholder; always raises ``NotImplementedError``."""
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def max_length(self):
        """Window size used by ``loglikelihood_rolling`` when slicing requests."""
        return 1024

    @property
    def max_gen_toks(self):
        """Placeholder; always raises ``NotImplementedError``."""
        # Placeholder implementation
        raise NotImplementedError()

    def tok_encode(self, string: str):
        """Placeholder; always raises ``NotImplementedError``."""
        # Placeholder implementation
        raise NotImplementedError()

    def tok_decode(self, tokens):
        """Placeholder; always raises ``NotImplementedError``."""
        # Placeholder implementation
        raise NotImplementedError()