Unverified Commit 1fdc005e authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

cleanup and fix variable names for GGUF model

parent 97bc9780
...@@ -10,14 +10,14 @@ from lm_eval.base import BaseLM ...@@ -10,14 +10,14 @@ from lm_eval.base import BaseLM
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_result(logprobs, context_lenght): def get_result(logprobs, context_length):
is_greedy = True is_greedy = True
offsets = logprobs['text_offset'] offsets = logprobs['text_offset']
tokens = logprobs['tokens'] tokens = logprobs['tokens']
tokens_logprobs = logprobs['token_logprobs'] tokens_logprobs = logprobs['token_logprobs']
idx = 0 idx = 0
while offsets[idx] < context_lenght: while offsets[idx] < context_length:
idx += 1 idx += 1
continuation_logprobs = sum(tokens_logprobs[idx:-1]) continuation_logprobs = sum(tokens_logprobs[idx:-1])
for i in range(idx, len(tokens)): for i in range(idx, len(tokens)):
...@@ -32,21 +32,19 @@ def get_result(logprobs, context_lenght): ...@@ -32,21 +32,19 @@ def get_result(logprobs, context_lenght):
class GGUFLM(BaseLM): class GGUFLM(BaseLM):
def __init__(self, base_url, max_length=2048, truncate=False):
    """Client-side adapter for a GGUF model served over HTTP.

    Args:
        base_url: Root URL of the llama.cpp-compatible completion server.
        max_length: Maximum context length (in tokens) requests may use.
        truncate: Whether over-long inputs should be truncated by callers.
            (Restored as a defaulted parameter: the body assigns
            ``self.truncate = truncate``, which would raise NameError
            without it.)
    """
    super().__init__()
    self.base_url = base_url
    self.truncate = truncate
    # Number of top logprobs to request per token from the server.
    self.logprobs = 10
    # Deterministic decoding for evaluation.
    self.temperature = 0.0
    self.max_length = max_length
def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs): def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
for _ in range(retries): for _ in range(retries):
try: try:
prompt = context prompt = context
request = {'prompt': prompt, 'logprobs': self.logpobs, request = {'prompt': prompt, 'logprobs': self.logprobs,
'temperature': self.temperature} 'temperature': self.temperature}
if continuation: if continuation:
prompt += continuation prompt += continuation
...@@ -105,19 +103,7 @@ class GGUFLM(BaseLM): ...@@ -105,19 +103,7 @@ class GGUFLM(BaseLM):
return res return res
def loglikelihood_rolling(self, requests):
    """Rolling loglikelihood is not available for server-backed GGUF models.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("loglikelihood_rolling not yet supported for GGUF models")
def _model_call(self, inps): def _model_call(self, inps):
# Placeholder implementation # Placeholder implementation
...@@ -128,10 +114,10 @@ class GGUFLM(BaseLM): ...@@ -128,10 +114,10 @@ class GGUFLM(BaseLM):
raise NotImplementedError() raise NotImplementedError()
def tok_encode(self, string: str):
    """Client-side tokenization is not supported; the server tokenizes.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def tok_decode(self, tokens):
    """Client-side detokenization is not supported; the server decodes.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
@property @property
def batch_size(self): def batch_size(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment