Unverified Commit 35bdecd3 authored by Matt Hoffner, committed by GitHub

Merge pull request #1 from LorenzoMinto/master

Return score from continuation logprobs
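Instead of scoring a (context, continuation) pair with only the first returned token logprob, the request now echoes the prompt back (echo=True, max_tokens=1) at temperature 0.0, and a new get_result helper sums the log-probabilities of the continuation tokens and checks whether greedy decoding would reproduce them.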
parents b011af90 9b876402
@@ -10,6 +10,27 @@ from lm_eval.base import BaseLM
logger = logging.getLogger(__name__)
+def get_result(logprobs, context_length):
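    # Sum the log-probabilities of the continuation tokens (those whose text
    # offset falls at or beyond the end of the context) and check, for every
    # token from the start of the continuation onward, whether it was also the
    # model's top-ranked choice (is_greedy).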
+    is_greedy = True
+    offsets = logprobs['text_offset']
+    tokens = logprobs['tokens']
+    tokens_logprobs = logprobs['token_logprobs']
+    idx = 0
+    while offsets[idx] < context_length:
+        idx += 1
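    # idx now points at the first continuation token. The final entry is excluded
    # from the sum because the request is sent with echo=True and max_tokens=1,
    # so the last token in the response is the single newly generated token.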
+    continuation_logprobs = sum(tokens_logprobs[idx:-1])
+    for i in range(idx, len(tokens)):
+        token = tokens[i]
+        top_tokens = logprobs["top_logprobs"][i]
+        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
+        if top_token != token:
+            is_greedy = False
+            break
+    return continuation_logprobs, is_greedy
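
# Illustrative example (values are invented, not taken from a real server
# response): for context "The capital of France is" (24 characters) and
# continuation " Paris", the echoed logprobs might contain
#   text_offset    = [..., 21, 24, 30]
#   tokens         = [..., ' is', ' Paris', '\n']
#   token_logprobs = [..., -0.4, -0.2, -1.3]
# get_result(logprobs, 24) skips ahead to the entry at offset 24, sums the
# continuation log-probabilities (here just -0.2; the trailing generated token
# is dropped) and reports is_greedy by comparing each remaining token against
# its 'top_logprobs' entry.
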
class GGMLLM(BaseLM):
    def __init__(self, base_url, truncate=False):
        super().__init__()
@@ -17,6 +38,7 @@ class GGMLLM(BaseLM):
        self.truncate = truncate
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
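        # Number of alternatives to request per token position; get_result reads
        # them back from the response's 'top_logprobs' field.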
        self.logprobs = 10
+        self.temperature = 0.0
        self.max_length = 1024
        self.vocab_size = self.tokenizer.vocab_size
@@ -24,9 +46,11 @@ class GGMLLM(BaseLM):
        for _ in range(retries):
            try:
                prompt = context
+                request = {'prompt': prompt, 'logprobs': self.logprobs,
+                           'temperature': self.temperature}
                if continuation:
                    prompt += continuation
-                request = {'prompt': prompt, 'logprobs': self.logprobs}
+                    request.update({'prompt': prompt, 'max_tokens': 1, 'echo': True})
                if stop is not None:
                    request['stop'] = stop
                response = requests.post(f"{self.base_url}/v1/completions", json=request)
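                # Shape of the POST body at this point (illustrative values only):
                #   {'prompt': 'The capital of France is Paris', 'logprobs': 10,
                #    'temperature': 0.0, 'max_tokens': 1, 'echo': True, 'stop': [...]}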
@@ -38,7 +62,6 @@ class GGMLLM(BaseLM):
        else:
            raise Exception(f"Failed to get a valid response after {retries} retries.")

    def loglikelihood(self, requests):
        if not requests:
            return []
@@ -49,8 +72,7 @@ class GGMLLM(BaseLM):
                choice = response["choices"][0]
                logprobs = choice.get("logprobs")
                if logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]:
-                    logprob = logprobs["token_logprobs"][0]
-                    is_greedy = choice["finish_reason"] == "length"
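                    # len(context) is a character offset; get_result compares it
                    # against the response's 'text_offset' values to locate where
                    # the continuation starts.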
+                    logprob, is_greedy = get_result(logprobs, len(context))
                    res.append((logprob, is_greedy))
                else:
                    logger.warning("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.")
@@ -58,7 +80,6 @@ class GGMLLM(BaseLM):
logger.error(f"Invalid response for loglikelihood. Response: {response}")
assert False
return res
def greedy_until(self, requests):
if not requests:
@@ -89,16 +110,15 @@ class GGMLLM(BaseLM):
        for request in requests:
            logprobs = []
            for i in range(0, len(request), self.max_length):
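                # Note: 'request' is a string here, so each chunk below is a
                # character window of at most self.max_length characters, not a
                # token window.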
-                chunk = request[i:i+self.max_length]
-                chunk_loglikelihood = self.loglikelihood([(chunk, request[i+1:i+self.max_length+1])])
+                chunk = request[i:i + self.max_length]
+                chunk_loglikelihood = self.loglikelihood([(chunk, request[i + 1:i + self.max_length + 1])])
                logprobs.extend(chunk_loglikelihood)
            avg_loglikelihood = sum([logprob for logprob, _ in logprobs]) / len(logprobs)
            results.append((avg_loglikelihood, True))
        return results

    def _model_call(self, inps):
        # Placeholder implementation
        raise NotImplementedError()
@@ -112,7 +132,7 @@ class GGMLLM(BaseLM):
    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    @property
    def batch_size(self):
        # Placeholder implementation
@@ -128,10 +148,10 @@ class GGMLLM(BaseLM):
        # Placeholder implementation
        raise NotImplementedError()

    def max_length(self):
        return self.max_length
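        # Note: the max_length attribute assigned in __init__ (1024) shadows this
        # method on instances, so self.max_length resolves to the integer rather
        # than to this function.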

    @property
    def max_gen_toks(self):
        # Placeholder implementation
-        raise NotImplementedError()
\ No newline at end of file
+        raise NotImplementedError()