"""lm-eval-harness adapter that scores and generates text with a GGML model
served over HTTP (requests are POSTed to ``<base_url>/v1/completions``).
"""

import logging
import time

import requests
import transformers
from requests.exceptions import RequestException
from tqdm import tqdm

from lm_eval.base import BaseLM
from lm_eval.utils import Reorderer  # NOTE(review): not referenced in this file; remove if unused

logger = logging.getLogger(__name__)


class GGMLLM(BaseLM):
    """lm-eval-harness ``BaseLM`` backed by a GGML model behind an HTTP server.

    Every scoring / generation request becomes a POST to
    ``{base_url}/v1/completions``. Responses are read as OpenAI-style
    completions (``choices[0].text``, ``choices[0].logprobs.token_logprobs``,
    ``choices[0].finish_reason``). NOTE(review): confirm the server actually
    returns these fields.

    A local GPT-2 tokenizer backs ``tok_encode`` / ``tok_decode``; it is not
    necessarily the tokenizer of the served model.
    """

    def __init__(self, base_url, truncate=False, max_length=1024):
        """
        Args:
            base_url: Server root URL; ``/v1/completions`` is appended to it.
                A trailing ``/`` is stripped so the joined URL has no ``//``.
            truncate: Stored as ``self.truncate``; no method in this class reads it.
            max_length: Window size used by ``loglikelihood_rolling``. Coerced to
                ``int`` so a value parsed from a model-args string is accepted.
        """
        super().__init__()
        self.base_url = base_url.rstrip("/")
        self.truncate = truncate

        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
        # Value sent as the ``logprobs`` field of every completion request.
        # (Attribute name kept as-is so external references keep working.)
        self.logpobs = 10
        # Backing store for the ``max_length`` property below; a property (like the
        # other BaseLM members in this class) cannot be shadowed by an instance attr.
        self._max_length = int(max_length)
        self.vocab_size = self.tokenizer.vocab_size

    def ggml_completion(self, context, continuation=None, stop=None, retries=3, delay=5,
                        timeout=None, **kwargs):
        """POST one completion request and return the decoded JSON body.

        Args:
            context: Prompt text.
            continuation: Optional text appended to ``context`` before sending.
            stop: Optional value sent as the request's ``stop`` field.
            retries: Maximum number of attempts.
            delay: Seconds to sleep between failed attempts.
            timeout: Passed to ``requests.post``; ``None`` (default) waits indefinitely.
            **kwargs: Extra JSON fields merged into the request body. ``prompt`` and
                ``logprobs`` are always set by this method and take precedence.

        Returns:
            The parsed JSON response.

        Raises:
            RuntimeError: If every attempt fails; the last ``RequestException``
                (if any) is chained as ``__cause__``.
        """
        prompt = context + continuation if continuation else context
        payload = {**kwargs, "prompt": prompt, "logprobs": self.logpobs}
        if stop is not None:
            payload["stop"] = stop
        url = f"{self.base_url}/v1/completions"

        last_error = None
        for attempt in range(1, retries + 1):
            try:
                response = requests.post(url, json=payload, timeout=timeout)
                response.raise_for_status()
                return response.json()
            except RequestException as e:
                last_error = e
                logger.error(f"RequestException: {e}")
                # Back off only when another attempt will actually follow.
                if attempt < retries:
                    time.sleep(delay)
        raise RuntimeError(
            f"Failed to get a valid response after {retries} retries."
        ) from last_error

    def loglikelihood(self, requests):
        """Score each ``(context, continuation)`` pair.

        Returns:
            One ``(logprob, is_greedy)`` tuple per request, in request order.

        Raises:
            RuntimeError: If a response has no choices or no ``token_logprobs``.
                Raising (instead of skipping) keeps results aligned with requests.
        """
        if not requests:
            return []

        res = []
        for context, continuation in tqdm(requests):
            response = self.ggml_completion(context=context, continuation=continuation)
            if not (response and "choices" in response and response["choices"]):
                message = f"Invalid response for loglikelihood. Response: {response}"
                logger.error(message)
                raise RuntimeError(message)

            choice = response["choices"][0]
            logprobs = choice.get("logprobs")
            if not (logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]):
                message = "Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
                logger.error(message)
                raise RuntimeError(f"{message} Response: {response}")

            # NOTE(review): ``context + continuation`` is sent as the prompt without
            # ``echo``, so (assuming OpenAI-style semantics) ``token_logprobs[0]`` is the
            # logprob of the first *generated* token, not of ``continuation``, and
            # ``finish_reason == "length"`` is only a proxy for greediness. Confirm.
            logprob = logprobs["token_logprobs"][0]
            is_greedy = choice.get("finish_reason") == "length"
            res.append((logprob, is_greedy))
        return res

    def greedy_until(self, requests):
        """Generate text for each ``(context, request_args)`` request.

        Generation is stopped at ``request_args["until"]`` (sent as ``stop``).

        Returns:
            One entry per request: the stripped generated text, or ``None`` when
            the response is malformed (so results stay aligned with requests).
        """
        if not requests:
            return []

        res = []
        for request in tqdm(requests):
            context, request_args = request[0], request[1]
            response = self.ggml_completion(context=context, stop=request_args["until"])

            has_choice = bool(response and "choices" in response and response["choices"])
            choice = response["choices"][0] if has_choice else None
            if choice is not None and "text" in choice:
                res.append(choice["text"].strip())
            else:
                logger.error(f"Invalid response for greedy_until. Response: {response}")
                res.append(None)  # Add default value in case of error
        return res

    def loglikelihood_rolling(self, requests):
        """Score each request by windows of ``self.max_length`` items.

        Returns:
            One ``(mean_window_logprob, True)`` tuple per request; an empty request
            scores ``(0.0, True)`` instead of dividing by zero.
        """
        # NOTE(review): several aspects look unintended; confirm before trusting scores:
        #   * windows slice ``request`` directly (characters for a str) and
        #     ``self.max_length`` is used as that count, not a token count;
        #   * each window is sent as ``chunk + request[start + 1:...]``, i.e. the window
        #     followed by a near-identical copy of itself shifted by one item;
        #   * the score is the *mean* of window values returned as ``(score, True)``;
        #     harness perplexity metrics presumably expect a summed float per request;
        #   * if requests arrive as ``(string,)`` tuples, ``request`` is the tuple itself.
        results = []
        for request in requests:
            window_logprobs = []
            for start in range(0, len(request), self.max_length):
                chunk = request[start:start + self.max_length]
                continuation = request[start + 1:start + self.max_length + 1]
                window_logprobs.extend(self.loglikelihood([(chunk, continuation)]))

            if not window_logprobs:
                # Empty request: nothing to score; log(1) == 0.
                results.append((0.0, True))
                continue

            mean_logprob = sum(lp for lp, _ in window_logprobs) / len(window_logprobs)
            results.append((mean_logprob, True))
        return results

    # The token-level hooks below are not used by this class: it overrides the
    # public request methods above and goes through the HTTP server instead.

    def _model_call(self, inps):
        # Placeholder implementation
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Placeholder implementation
        raise NotImplementedError()

    def tok_encode(self, string: str):
        """Encode ``string`` with the local GPT-2 tokenizer, without special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        """Decode token ids with the local GPT-2 tokenizer."""
        return self.tokenizer.decode(tokens)

    @property
    def batch_size(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def device(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def eot_token_id(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def max_length(self):
        """Window size used by ``loglikelihood_rolling`` (set via ``__init__``)."""
        return self._max_length

    @max_length.setter
    def max_length(self, value):
        # Allow callers to reassign the window size after construction.
        self._max_length = int(value)

    @property
    def max_gen_toks(self):
        # Placeholder implementation
        raise NotImplementedError()