gguf.py 4.67 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
import logging
import time
Rayyyyy's avatar
Rayyyyy committed
3
4

import requests
Rayyyyy's avatar
Rayyyyy committed
5
from requests.exceptions import RequestException
Rayyyyy's avatar
Rayyyyy committed
6
7
8
9
10
from tqdm import tqdm

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model

Rayyyyy's avatar
Rayyyyy committed
11
12
13
14
15
16

logger = logging.getLogger(__name__)


def get_result(logprobs, context_length):
    """Score the continuation part of an echoed completion.

    Args:
        logprobs: Mapping with the parallel, per-token lists ``"text_offset"``,
            ``"tokens"``, ``"token_logprobs"`` and ``"top_logprobs"`` (each
            ``top_logprobs`` entry maps candidate token -> log-probability).
        context_length: Length of the context text; tokens whose text offset
            is below this value are treated as context and skipped.

    Returns:
        ``(continuation_logprobs, is_greedy)`` where ``continuation_logprobs``
        is the sum of the token log-probabilities from the first continuation
        token up to, but excluding, the final token, and ``is_greedy`` is True
        when every token from the first continuation token onward is the
        highest-scoring entry of its ``top_logprobs``.  If no token starts at
        or after ``context_length`` the continuation is empty and the result
        is ``(0, True)``.
    """
    offsets = logprobs["text_offset"]
    tokens = logprobs["tokens"]
    tokens_logprobs = logprobs["token_logprobs"]

    # Index of the first token that starts at or beyond the end of the context.
    # Falls back to len(offsets) so an empty continuation yields an empty slice
    # instead of running off the end of the list.
    start = next(
        (i for i, offset in enumerate(offsets) if offset >= context_length),
        len(offsets),
    )

    # The final token is deliberately left out of the sum.
    # NOTE(review): presumably it is the single extra token produced because
    # scoring requests are sent with max_tokens=1 and echo enabled -- confirm.
    continuation_logprobs = sum(tokens_logprobs[start:-1])

    for position in range(start, len(tokens)):
        candidates = logprobs["top_logprobs"][position]
        if max(candidates, key=candidates.get) != tokens[position]:
            return continuation_logprobs, False

    return continuation_logprobs, True


Rayyyyy's avatar
Rayyyyy committed
36
37
38
@register_model("gguf", "ggml")
class GGUFLM(LM):
    """Model backend that talks to an HTTP completions server.

    Every request is POSTed as JSON to ``{base_url}/v1/completions`` and the
    first entry of the returned ``"choices"`` list is used.
    """

    def __init__(self, base_url=None, max_length=2048, **kwargs):
        """Store connection and sampling settings.

        Args:
            base_url: Root URL of the completions server (required).
            max_length: Stored on the instance as ``self.max_length``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.base_url = base_url
        assert self.base_url, "must pass `base_url` to use GGUF LM!"
        # Number of top log-probabilities requested per token.
        self.logprobs = 10
        self.temperature = 0.0
        self.max_length = max_length

    def gguf_completion(
        self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs
    ):
        """POST a completion request, retrying on ``RequestException``.

        Args:
            context: Prompt text.
            continuation: Optional text appended to the prompt.  When given
                (and non-empty) the request asks for one new token with
                ``echo`` enabled so the prompt tokens are returned as well.
            stop: Optional stop sequence(s) forwarded as ``"stop"``.
            retries: Maximum number of attempts.
            delay: Seconds to wait between attempts.
            **kwargs: Accepted and ignored.

        Returns:
            The decoded JSON body of the first successful response.

        Raises:
            Exception: If every attempt failed; the last ``RequestException``
                is attached as the cause.
        """
        # The payload does not change between attempts, so build it once.
        request = {
            "prompt": context,
            "logprobs": self.logprobs,
            "temperature": self.temperature,
        }
        if continuation:
            request.update(
                {"prompt": context + continuation, "max_tokens": 1, "echo": True}
            )
        if stop is not None:
            request["stop"] = stop

        last_error = None
        for attempt in range(1, retries + 1):
            try:
                response = requests.post(
                    f"{self.base_url}/v1/completions", json=request
                )
                response.raise_for_status()
                return response.json()
            except RequestException as e:
                last_error = e
                logger.error(f"RequestException: {e}")
                # Only wait when another attempt will actually follow.
                if attempt < retries:
                    time.sleep(delay)

        raise Exception(
            f"Failed to get a valid response after {retries} retries."
        ) from last_error

    @staticmethod
    def _first_choice(response):
        """Return the first item of ``response["choices"]``, or None if missing/empty."""
        if response and "choices" in response and response["choices"]:
            return response["choices"][0]
        return None

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        """Score each ``(context, continuation)`` request.

        Args:
            requests: Sequence of objects whose ``.args`` is a
                ``(context, continuation)`` pair.
            disable_tqdm: Disable the progress bar.

        Returns:
            One ``(logprob, is_greedy)`` tuple per request, in request order.

        Raises:
            AssertionError: If a response has no usable ``"choices"`` entry.
            ValueError: If a choice carries no usable ``"token_logprobs"``;
                skipping it would leave the results misaligned with the
                requests.
        """
        if not requests:
            return []

        res = []
        for context, continuation in tqdm(
            [req.args for req in requests], disable=disable_tqdm
        ):
            response = self.gguf_completion(context=context, continuation=continuation)
            choice = self._first_choice(response)
            if choice is None:
                message = f"Invalid response for loglikelihood. Response: {response}"
                logger.error(message)
                # Explicit raise instead of `assert False`, which is stripped
                # under `python -O`; the exception type is kept for callers.
                raise AssertionError(message)

            logprobs = choice.get("logprobs")
            if not (
                logprobs
                and "token_logprobs" in logprobs
                and logprobs["token_logprobs"]
            ):
                message = "Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
                logger.warning(message)
                raise ValueError(message)

            res.append(get_result(logprobs, len(context)))
        return res

    def generate_until(self, requests, disable_tqdm: bool = False):
        """Generate text for each request until a stop sequence is hit.

        Args:
            requests: Sequence of objects whose ``.args`` holds the prompt at
                index 0 and a dict of generation arguments at index 1; the
                dict's ``"until"`` entry (default ``["</s>"]``) is sent as the
                stop sequences.
            disable_tqdm: Disable the progress bar.

        Returns:
            One entry per request, in request order: the stripped generated
            text, or ``None`` when the response had no usable text.
        """
        if not requests:
            return []

        res = []
        for request in tqdm([req.args for req in requests], disable=disable_tqdm):
            inp = request[0]
            request_args = request[1]
            until = request_args.get("until", ["</s>"])
            response = self.gguf_completion(context=inp, stop=until)
            choice = self._first_choice(response)
            if choice is not None and "text" in choice:
                res.append(choice["text"].strip())
            else:
                logger.error(f"Invalid response for greedy_until. Response: {response}")
                res.append(None)  # Keep results aligned with requests on error
        return res

    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
        """Not implemented for this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "loglikelihood_rolling not yet supported for GGUF models"
        )