Unverified Commit 1fdc005e authored by Hailey Schoelkopf, committed by GitHub

cleanup and fix variable names for GGUF model

parent 97bc9780
@@ -10,14 +10,14 @@ from lm_eval.base import BaseLM
 logger = logging.getLogger(__name__)
 
 
-def get_result(logprobs, context_lenght):
+def get_result(logprobs, context_length):
     is_greedy = True
     offsets = logprobs['text_offset']
     tokens = logprobs['tokens']
     tokens_logprobs = logprobs['token_logprobs']
 
     idx = 0
-    while offsets[idx] < context_lenght:
+    while offsets[idx] < context_length:
         idx += 1
     continuation_logprobs = sum(tokens_logprobs[idx:-1])
     for i in range(idx, len(tokens)):
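
For context, get_result walks an OpenAI-style logprobs payload: it advances past the prompt using the per-token character offsets, then sums the continuation's token log-probabilities. A minimal sketch of the expected input with hypothetical values; the top_logprobs key is an assumption about the folded is_greedy loop below this hunk:

    payload = {
        'text_offset': [0, 5, 11],             # character offset of each token
        'tokens': ['Hello', ' world', '!'],
        'token_logprobs': [-0.3, -1.2, -0.7],
        'top_logprobs': [{'Hello': -0.3}, {' world': -1.2}, {'!': -0.7}],  # assumed: read by the folded greedy check
    }
    # the prompt 'Hello' is 5 characters, so scoring starts at the second token;
    # the [idx:-1] slice above drops the final token
    logprob, is_greedy = get_result(payload, context_length=5)  # -> (-1.2, True)
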
@@ -32,21 +32,19 @@ def get_result(logprobs, context_lenght):
 class GGUFLM(BaseLM):
-    def __init__(self, base_url, truncate=False):
+    def __init__(self, base_url, max_length=2048):
         super().__init__()
         self.base_url = base_url
-        self.truncate = truncate
         self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
-        self.logpobs = 10
+        self.logprobs = 10
         self.temperature = 0.0
-        self.max_length = 1024
-        self.vocab_size = self.tokenizer.vocab_size
+        self.max_length = max_length
 
     def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
         for _ in range(retries):
             try:
                 prompt = context
-                request = {'prompt': prompt, 'logprobs': self.logpobs,
+                request = {'prompt': prompt, 'logprobs': self.logprobs,
                            'temperature': self.temperature}
                 if continuation:
                     prompt += continuation
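
The folded remainder of gguf_completion presumably posts this request to the model server. A standalone sketch of that round trip, assuming a hypothetical base_url and an OpenAI-compatible /v1/completions endpoint such as the one served by llama-cpp-python:

    import requests

    base_url = 'http://localhost:8000'  # hypothetical server address
    request = {'prompt': 'The capital of France is', 'logprobs': 10, 'temperature': 0.0}
    response = requests.post(f'{base_url}/v1/completions', json=request)
    # assuming an OpenAI-style response body
    logprobs = response.json()['choices'][0]['logprobs']
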
@@ -105,19 +103,7 @@ class GGUFLM(BaseLM):
         return res
 
     def loglikelihood_rolling(self, requests):
-        results = []
-        for request in requests:
-            logprobs = []
-            for i in range(0, len(request), self.max_length):
-                chunk = request[i:i + self.max_length]
-                chunk_loglikelihood = self.loglikelihood([(chunk, request[i + 1:i + self.max_length + 1])])
-                logprobs.extend(chunk_loglikelihood)
-            avg_loglikelihood = sum([logprob for logprob, _ in logprobs]) / len(logprobs)
-            results.append((avg_loglikelihood, True))
-        return results
+        raise NotImplementedError("loglikelihood_rolling not yet supported for GGUF models")
 
     def _model_call(self, inps):
         # Placeholder implementation
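
The removed loglikelihood_rolling sliced raw request strings into max_length-character chunks and averaged the chunk scores, which is not the token-level rolling log-likelihood the harness expects, so an explicit error is clearer. A hypothetical caller now sees:

    lm = GGUFLM(base_url='http://localhost:8000')  # hypothetical address
    try:
        lm.loglikelihood_rolling(['a long document ...'])
    except NotImplementedError as err:
        print(err)  # loglikelihood_rolling not yet supported for GGUF models
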
@@ -128,10 +114,10 @@ class GGUFLM(BaseLM):
         raise NotImplementedError()
 
     def tok_encode(self, string: str):
-        return self.tokenizer.encode(string, add_special_tokens=False)
+        raise NotImplementedError()
 
     def tok_decode(self, tokens):
-        return self.tokenizer.decode(tokens)
+        raise NotImplementedError()
 
     @property
     def batch_size(self):
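
Taken together, the constructor change makes the context window caller-configurable instead of hard-coded at 1024, and the misspelled logpobs and context_lenght names are gone. A hypothetical instantiation after this commit:

    # before: GGUFLM(base_url, truncate=False) with max_length fixed at 1024
    lm = GGUFLM(base_url='http://localhost:8000', max_length=2048)
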