"examples/vscode:/vscode.git/clone" did not exist on "e4f59ba073ee55dd4d720db8a8883220859488b1"
Unverified Commit 1fdc005e authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

cleanup and fix variable names for GGUF model

parent 97bc9780
...@@ -10,14 +10,14 @@ from lm_eval.base import BaseLM ...@@ -10,14 +10,14 @@ from lm_eval.base import BaseLM
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_result(logprobs, context_lenght): def get_result(logprobs, context_length):
is_greedy = True is_greedy = True
offsets = logprobs['text_offset'] offsets = logprobs['text_offset']
tokens = logprobs['tokens'] tokens = logprobs['tokens']
tokens_logprobs = logprobs['token_logprobs'] tokens_logprobs = logprobs['token_logprobs']
idx = 0 idx = 0
while offsets[idx] < context_lenght: while offsets[idx] < context_length:
idx += 1 idx += 1
continuation_logprobs = sum(tokens_logprobs[idx:-1]) continuation_logprobs = sum(tokens_logprobs[idx:-1])
for i in range(idx, len(tokens)): for i in range(idx, len(tokens)):
...@@ -32,21 +32,19 @@ def get_result(logprobs, context_lenght): ...@@ -32,21 +32,19 @@ def get_result(logprobs, context_lenght):
class GGUFLM(BaseLM): class GGUFLM(BaseLM):
def __init__(self, base_url, truncate=False): def __init__(self, base_url, max_length=2048):
super().__init__() super().__init__()
self.base_url = base_url self.base_url = base_url
self.truncate = truncate self.truncate = truncate
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2") self.logprobs = 10
self.logpobs = 10
self.temperature = 0.0 self.temperature = 0.0
self.max_length = 1024 self.max_length = max_length
self.vocab_size = self.tokenizer.vocab_size
def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs): def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
for _ in range(retries): for _ in range(retries):
try: try:
prompt = context prompt = context
request = {'prompt': prompt, 'logprobs': self.logpobs, request = {'prompt': prompt, 'logprobs': self.logprobs,
'temperature': self.temperature} 'temperature': self.temperature}
if continuation: if continuation:
prompt += continuation prompt += continuation
...@@ -105,19 +103,7 @@ class GGUFLM(BaseLM): ...@@ -105,19 +103,7 @@ class GGUFLM(BaseLM):
return res return res
def loglikelihood_rolling(self, requests): def loglikelihood_rolling(self, requests):
results = [] raise NotImplementedError("loglikelihood_rolling not yet supported for GGUF models")
for request in requests:
logprobs = []
for i in range(0, len(request), self.max_length):
chunk = request[i:i + self.max_length]
chunk_loglikelihood = self.loglikelihood([(chunk, request[i + 1:i + self.max_length + 1])])
logprobs.extend(chunk_loglikelihood)
avg_loglikelihood = sum([logprob for logprob, _ in logprobs]) / len(logprobs)
results.append((avg_loglikelihood, True))
return results
def _model_call(self, inps): def _model_call(self, inps):
# Placeholder implementation # Placeholder implementation
...@@ -128,10 +114,10 @@ class GGUFLM(BaseLM): ...@@ -128,10 +114,10 @@ class GGUFLM(BaseLM):
raise NotImplementedError() raise NotImplementedError()
def tok_encode(self, string: str): def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False) raise NotImplementedError()
def tok_decode(self, tokens): def tok_decode(self, tokens):
return self.tokenizer.decode(tokens) raise NotImplementedError()
@property @property
def batch_size(self): def batch_size(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment