# gguf.py
# Recovered from a repository blame view (file size shown: 4.95 KB; author: Matt Hoffner).
1
import requests
2
import logging
Matt Hoffner's avatar
Matt Hoffner committed
3
import time
Matt Hoffner's avatar
Matt Hoffner committed
4
5
from tqdm import tqdm
from requests.exceptions import RequestException
6
import transformers
Matt Hoffner's avatar
Matt Hoffner committed
7
8
9
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM

10
11
logger = logging.getLogger(__name__)

Matt Hoffner's avatar
Matt Hoffner committed
12

13
def get_result(logprobs, context_length):
    """Score the continuation part of an echoed completion.

    ``logprobs`` is the ``logprobs`` object of a completion choice produced by a
    scoring request (``GGUFLM.gguf_completion`` with a continuation sends
    ``echo=True`` and ``max_tokens=1``). It therefore holds per-token data for the
    whole prompt followed by one generated token.

    Args:
        logprobs: Mapping with the parallel lists ``text_offset``, ``tokens``,
            ``token_logprobs`` and ``top_logprobs``. Each ``top_logprobs`` entry
            examined here is a dict mapping token -> logprob.
        context_length: Offset, in the same units as ``text_offset``, at which
            the continuation begins. Tokens starting before it are context.

    Returns:
        ``(continuation_logprob, is_greedy)``. The first value is the sum of the
        continuation tokens' logprobs. The second is True when each continuation
        token is the highest-scoring entry of its ``top_logprobs`` dict.
        If no token starts at or after ``context_length``, returns ``(0, True)``.
    """
    offsets = logprobs["text_offset"]
    tokens = logprobs["tokens"]
    token_logprobs = logprobs["token_logprobs"]

    # Skip the context tokens. The bound check stops a response whose offsets
    # never reach ``context_length`` from raising IndexError.
    start = 0
    while start < len(offsets) and offsets[start] < context_length:
        start += 1

    # The last entry is the single token generated because of max_tokens=1. It
    # is not part of the continuation, so exclude it from both the sum and the
    # greedy check.
    last = len(tokens) - 1
    continuation_logprob = sum(token_logprobs[start:-1])

    is_greedy = True
    for i in range(start, last):
        top = logprobs["top_logprobs"][i]
        # Ties resolve to the first key in dict order, as max() does.
        if max(top, key=top.get) != tokens[i]:
            is_greedy = False
            break

    return continuation_logprob, is_greedy


class GGUFLM(BaseLM):
    """lm-eval model backed by a remote completions server.

    Every request is a POST to ``{base_url}/v1/completions``. Only
    ``loglikelihood`` and ``greedy_until`` are implemented. Members that would
    need local access to the model or its tokenizer raise
    ``NotImplementedError``.
    """

    def __init__(self, base_url, max_length=2048, truncate=False):
        """Store connection and request settings.

        Args:
            base_url: Server root URL. ``/v1/completions`` is appended per request.
            max_length: Value exposed through the ``max_length`` property.
            truncate: Stored as ``self.truncate``. This class does not read it.
        """
        super().__init__()
        self.base_url = base_url
        self.truncate = truncate
        # Sent as the request's ``logprobs`` field. Presumably the number of
        # alternatives per token returned in ``top_logprobs`` (server-defined).
        self.logprobs = 10
        # Sent as ``temperature``. 0.0 presumably selects greedy decoding.
        self.temperature = 0.0
        self._max_length = max_length

    def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
        """POST one completion request, retrying on ``RequestException``.

        If ``continuation`` is given, the prompt is ``context + continuation``,
        sent with ``echo=True`` and ``max_tokens=1`` (scoring mode). Otherwise
        the prompt is ``context`` alone (generation mode).

        Args:
            context: Prompt prefix.
            continuation: Text to score after ``context``, or ``None`` to generate.
            stop: Optional stop sequence(s), forwarded as the ``stop`` field.
            retries: Maximum number of attempts.
            delay: Seconds to wait between failed attempts.
            **kwargs: Accepted for call-site compatibility. Ignored.

        Returns:
            The decoded JSON body of the first successful response.

        Raises:
            RuntimeError: If no attempt succeeded. It is chained to the last
                ``RequestException``, if any.
        """
        request = {"prompt": context, "logprobs": self.logprobs, "temperature": self.temperature}
        # ``is not None`` so that an empty continuation is still scored
        # (echo mode) instead of silently turning into a generation request.
        if continuation is not None:
            request.update({"prompt": context + continuation, "max_tokens": 1, "echo": True})
        if stop is not None:
            request["stop"] = stop

        last_error = None
        for attempt in range(retries):
            try:
                response = requests.post(f"{self.base_url}/v1/completions", json=request)
                response.raise_for_status()
                return response.json()
            except RequestException as e:
                last_error = e
                logger.error(f"RequestException: {e}")
                if attempt < retries - 1:
                    time.sleep(delay)  # wait before retrying; pointless after the last try
        raise RuntimeError(f"Failed to get a valid response after {retries} retries.") from last_error

    def loglikelihood(self, requests):
        """Score each ``(context, continuation)`` string pair.

        Returns:
            One ``(logprob, is_greedy)`` tuple per request, in request order.

        Raises:
            RuntimeError: If a response has no choices or no ``token_logprobs``.
                Results must stay aligned with ``requests``, so a request is
                never silently skipped.
        """
        if not requests:
            return []

        res = []
        for context, continuation in tqdm(requests):
            response = self.gguf_completion(context=context, continuation=continuation)
            if not (response and "choices" in response and response["choices"]):
                message = f"Invalid response for loglikelihood. Response: {response}"
                logger.error(message)
                raise RuntimeError(message)

            logprobs = response["choices"][0].get("logprobs")
            if not (logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]):
                message = "Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
                logger.warning(message)
                raise RuntimeError(message)

            # The echoed prompt is context + continuation. This assumes
            # ``text_offset`` counts characters of that prompt, so the
            # continuation starts at len(context).
            logprob, is_greedy = get_result(logprobs, len(context))
            res.append((logprob, is_greedy))
        return res

    def greedy_until(self, requests):
        """Generate text for each ``(context, {"until": stop})`` request.

        Returns:
            One stripped generated string per request. The entry is ``None``
            (and the response is logged) when the response has no usable
            ``text``, which keeps results aligned with requests.
        """
        if not requests:
            return []

        res = []
        for request in tqdm(requests):
            context, request_args = request[0], request[1]
            response = self.gguf_completion(context=context, stop=request_args["until"])
            if response and "choices" in response and response["choices"] and "text" in response["choices"][0]:
                res.append(response["choices"][0]["text"].strip())
            else:
                logger.error(f"Invalid response for greedy_until. Response: {response}")
                res.append(None)  # Add default value in case of error
        return res

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError("loglikelihood_rolling not yet supported for GGUF models")

    # The members below are required by BaseLM but need local model or
    # tokenizer access, which this HTTP-only adapter does not have.

    def _model_call(self, inps):
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        raise NotImplementedError()

    def tok_encode(self, string: str):
        raise NotImplementedError()

    def tok_decode(self, tokens):
        raise NotImplementedError()

    @property
    def batch_size(self):
        raise NotImplementedError()

    @property
    def device(self):
        raise NotImplementedError()

    @property
    def eot_token_id(self):
        raise NotImplementedError()

    @property
    def max_length(self):
        """Maximum length passed to the constructor (default 2048)."""
        return self._max_length

    @max_length.setter
    def max_length(self, value):
        # Kept assignable: earlier versions stored it as a plain instance attribute.
        self._max_length = value

    @property
    def max_gen_toks(self):
        raise NotImplementedError()