"""GGML model adapter for lm_eval that talks to an HTTP completion server."""

import logging
import time

import requests
import transformers
from requests.exceptions import RequestException
from tqdm import tqdm

from lm_eval.base import BaseLM
from lm_eval.utils import Reorderer

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)


class GGMLLM(BaseLM):
    """Language-model adapter that scores and generates text via an HTTP server.

    Every request is sent as a JSON POST to ``{base_url}/v1/completions``.
    The model is never called directly in this process, so the tensor-level
    hooks (``_model_call``, ``_model_generate``) and several properties are
    left unimplemented and raise ``NotImplementedError``.
    """

    def __init__(self, base_url, truncate=False):
        """Store the server location and set up tokenizer-related defaults.

        Args:
            base_url: Root URL of the completion server (no trailing
                ``/v1/completions``; that path is appended per request).
            truncate: Stored on the instance; not otherwise used in this class.
        """
        super().__init__()
        self.base_url = base_url
        self.truncate = truncate
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
        # Sent as the ``logprobs`` field of every request. The attribute name
        # is misspelled (sic) but kept so existing external reads keep working.
        self.logpobs = 10
        # Backing field for the ``max_length`` property.
        self._max_length = 1024
        self.vocab_size = self.tokenizer.vocab_size

    def ggml_completion(self, base_url, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
        """POST a completion request and return the decoded JSON response.

        Args:
            base_url: Root URL of the completion server.
            context: Prompt text.
            continuation: Optional text appended to ``context`` to form the
                prompt; ignored when falsy.
            stop: Optional stop value forwarded as the ``stop`` field.
            retries: Maximum number of attempts.
            delay: Seconds to wait between failed attempts.
            **kwargs: Accepted for call compatibility; currently unused.

        Returns:
            The parsed JSON body of the first successful response.

        Raises:
            RuntimeError: If no attempt succeeds. The last
                ``RequestException`` (if any) is attached as the cause.
        """
        # The payload does not change between attempts, so build it once.
        prompt = context
        if continuation:
            prompt += continuation
        request = {'prompt': prompt, 'logprobs': self.logpobs}
        if stop is not None:
            request['stop'] = stop

        # Python unbinds the ``except ... as`` name when the handler exits,
        # so the last failure is kept in a separate variable for the final
        # error message.
        last_error = None
        for attempt in range(retries):
            try:
                response = requests.post(f"{base_url}/v1/completions", json=request)
                response.raise_for_status()
                return response.json()
            except RequestException as e:
                logger.error(f"RequestException: {e}")
                last_error = e
                # No point sleeping after the final attempt.
                if attempt < retries - 1:
                    time.sleep(delay)  # wait before retrying

        raise RuntimeError(
            f"Failed to get a valid response after {retries} retries. Last exception: {last_error}"
        ) from last_error

    @staticmethod
    def _first_choice(response):
        """Return ``response["choices"][0]``, or ``None`` if it is missing/empty."""
        if response and "choices" in response and response["choices"]:
            return response["choices"][0]
        return None

    def loglikelihood(self, requests):
        """Score each ``(context, continuation)`` pair via the server.

        Requests are processed in the order given and the results are
        returned in that same order. They are deliberately not sorted with
        ``Reorderer(requests, len)``: every request is a 2-tuple, so ``len``
        yields the same key for all of them and cannot provide a useful
        ordering.

        Args:
            requests: Iterable of ``(context, continuation)`` pairs.

        Returns:
            A list with one ``(logprob, is_greedy)`` tuple per request.

        Raises:
            RuntimeError: If a response has no usable choice or no usable
                ``token_logprobs``. Raising keeps the result list aligned
                with the requests instead of silently dropping an entry.
        """
        if not requests:
            return []

        res = []
        for context, continuation in tqdm(requests):
            response = self.ggml_completion(self.base_url, context=context, continuation=continuation)
            choice = self._first_choice(response)
            if choice is None:
                message = f"Invalid response for loglikelihood. Response: {response}"
                logger.error(message)
                # Explicit raise: a bare ``assert`` disappears under ``python -O``.
                raise RuntimeError(message)

            logprobs = choice.get("logprobs")
            if not (logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]):
                message = "Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
                logger.error(message)
                raise RuntimeError(message)

            # NOTE(review): only the first token's logprob is used, not a sum
            # over the continuation - confirm this matches what the server
            # returns and what the caller expects.
            logprob = logprobs["token_logprobs"][0]
            # NOTE(review): "greedy" is inferred from finish_reason == "length";
            # verify this against the server's response semantics.
            is_greedy = choice["finish_reason"] == "length"
            res.append((logprob, is_greedy))
        return res

    def greedy_until(self, requests):
        """Generate text for each request, stopping at its ``until`` value.

        Requests are processed in the order given (see ``loglikelihood`` for
        why they are not reordered).

        Args:
            requests: Sequence of ``(input_text, request_args)`` pairs, where
                ``request_args["until"]`` is forwarded as the stop value.

        Returns:
            A list with one entry per request: the generated text with
            surrounding whitespace stripped, or ``None`` when the response
            was unusable (the failure is logged).
        """
        if not requests:
            return []

        res = []
        for request in tqdm(requests):
            inp = request[0]
            request_args = request[1]
            until = request_args["until"]
            response = self.ggml_completion(self.base_url, context=inp, stop=until)
            choice = self._first_choice(response)
            if choice is not None and "text" in choice:
                res.append(choice["text"].strip())
            else:
                logger.error(f"Invalid response for greedy_until. Response: {response}")
                res.append(None)  # Add default value in case of error
        return res

    def loglikelihood_rolling(self, requests):
        """Score each request in ``max_length``-sized chunks and average them.

        NOTE(review): each request is sliced directly, so it is assumed to be
        a sliceable sequence such as a string. The "continuation" passed for
        each chunk is the same window shifted by one element. Confirm both
        against the caller's request format before relying on the numbers.

        Args:
            requests: Iterable of sliceable requests.

        Returns:
            A list with one ``(average_logprob, True)`` tuple per request.

        Raises:
            ZeroDivisionError: If a request is empty (no chunks are scored).
        """
        window = self.max_length
        results = []

        for request in requests:
            logprobs = []
            for start in range(0, len(request), window):
                chunk = request[start:start + window]
                shifted = request[start + 1:start + window + 1]
                logprobs.extend(self.loglikelihood([(chunk, shifted)]))

            avg_loglikelihood = sum(logprob for logprob, _ in logprobs) / len(logprobs)
            results.append((avg_loglikelihood, True))

        return results

    def _model_call(self, inps):
        # Placeholder implementation
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Placeholder implementation
        raise NotImplementedError()

    def tok_encode(self, string: str):
        """Encode ``string`` to token ids without adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        """Decode a sequence of token ids back to text."""
        return self.tokenizer.decode(tokens)

    @property
    def batch_size(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def device(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def eot_token_id(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def max_length(self):
        """Chunk size used by ``loglikelihood_rolling`` (1024 by default)."""
        return self._max_length

    @max_length.setter
    def max_length(self, value):
        # Keeps ``instance.max_length = n`` working, as it did when this was
        # a plain instance attribute.
        self._max_length = value

    @property
    def max_gen_toks(self):
        # Placeholder implementation
        raise NotImplementedError()