import logging
import time

import requests
from requests.exceptions import RequestException
from tqdm import tqdm

from lm_eval.base import BaseLM

logger = logging.getLogger(__name__)


def get_result(logprobs, context_length):
    """Sum the continuation log-probabilities and check whether the continuation
    was the greedy (argmax) decoding, given an OpenAI-style ``logprobs`` dict."""
    is_greedy = True
    offsets = logprobs["text_offset"]
    tokens = logprobs["tokens"]
    tokens_logprobs = logprobs["token_logprobs"]

    # Skip tokens belonging to the context: the continuation starts at the
    # first token whose character offset reaches the end of the context.
    idx = 0
    while offsets[idx] < context_length:
        idx += 1

    # The last token is the extra one generated by `max_tokens=1`, so it is
    # excluded from the continuation score.
    continuation_logprobs = sum(tokens_logprobs[idx:-1])

    # The continuation is greedy iff each of its tokens is also the
    # highest-probability token at its position.
    for i in range(idx, len(tokens)):
        token = tokens[i]
        top_tokens = logprobs["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy


class GGUFLM(BaseLM):
    """Adapter for a GGUF model served over an OpenAI-compatible
    ``/v1/completions`` HTTP endpoint (such as a llama.cpp-based server)."""

    def __init__(self, base_url, max_length=2048):
        super().__init__()
        self.base_url = base_url
        self.logprobs = 10
        self.temperature = 0.0
        # Stored privately so the `max_length` property below can expose it
        # without shadowing or recursion.
        self._max_length = max_length

    def gguf_completion(
        self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs
    ):
        for _ in range(retries):
            try:
                prompt = context
                request = {
                    "prompt": prompt,
                    "logprobs": self.logprobs,
                    "temperature": self.temperature,
                }
                if continuation:
                    # Scoring mode: echo the full prompt back with per-token
                    # logprobs and generate at most one extra token.
                    prompt += continuation
                    request.update({"prompt": prompt, "max_tokens": 1, "echo": True})
                if stop is not None:
                    request["stop"] = stop
                response = requests.post(f"{self.base_url}/v1/completions", json=request)
                response.raise_for_status()
                return response.json()
            except RequestException as e:
                logger.error(f"RequestException: {e}")
                time.sleep(delay)  # wait before retrying
        else:
            # The loop body never `break`s, so this `for ... else` clause runs
            # exactly when every attempt failed without returning.
            raise Exception(f"Failed to get a valid response after {retries} retries.")

    def loglikelihood(self, requests):
        if not requests:
            return []
        res = []
        for context, continuation in tqdm(requests):
            response = self.gguf_completion(context=context, continuation=continuation)
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]
                logprobs = choice.get("logprobs")
                if (
                    logprobs
                    and "token_logprobs" in logprobs
                    and logprobs["token_logprobs"]
                ):
                    logprob, is_greedy = get_result(logprobs, len(context))
                    res.append((logprob, is_greedy))
                else:
                    logger.warning(
                        "Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
                    )
            else:
                logger.error(f"Invalid response for loglikelihood. Response: {response}")
                assert False
        return res

    def greedy_until(self, requests):
        if not requests:
            return []

        res = []
        for request in tqdm(requests):
            inp = request[0]
            request_args = request[1]
            until = request_args["until"]
            response = self.gguf_completion(context=inp, stop=until)
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]
                if "text" in choice:
                    generated_text = choice["text"].strip()
                    res.append(generated_text)
                else:
                    logger.error(f"Invalid response for greedy_until. Response: {response}")
                    res.append(None)  # default value in case of error
            else:
                logger.error(f"Invalid response for greedy_until. Response: {response}")
                res.append(None)  # default value in case of error
        return res

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError(
            "loglikelihood_rolling not yet supported for GGUF models"
        )

    def _model_call(self, inps):
        # Unused: loglikelihood and greedy_until are overridden above.
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Unused: greedy_until is overridden above.
        raise NotImplementedError()

    def tok_encode(self, string: str):
        raise NotImplementedError()

    def tok_decode(self, tokens):
        raise NotImplementedError()

    @property
    def batch_size(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def device(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def eot_token_id(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def max_length(self):
        # Must read the private attribute: `return self.max_length` would
        # recurse, and a same-named instance attribute would shadow this
        # property.
        return self._max_length

    @property
    def max_gen_toks(self):
        # Placeholder implementation
        raise NotImplementedError()