Unverified Commit 35bdecd3 authored by Matt Hoffner, committed by GitHub

Merge pull request #1 from LorenzoMinto/master

Return score from continuation logprobs
parents b011af90 9b876402
@@ -10,6 +10,27 @@ from lm_eval.base import BaseLM

 logger = logging.getLogger(__name__)

+
+def get_result(logprobs, context_length):
+    is_greedy = True
+    offsets = logprobs['text_offset']
+    tokens = logprobs['tokens']
+    tokens_logprobs = logprobs['token_logprobs']
+
+    # Walk past the context: the continuation starts at the first token whose
+    # character offset reaches the end of the context string.
+    idx = 0
+    while offsets[idx] < context_length:
+        idx += 1
+    # Sum the continuation's log-probs. The final token is excluded because it
+    # is the single newly generated token (the request uses echo=True with
+    # max_tokens=1), not part of the echoed continuation.
+    continuation_logprobs = sum(tokens_logprobs[idx:-1])
+    # The continuation counts as greedy only if every one of its tokens was
+    # the top-ranked candidate at its position.
+    for i in range(idx, len(tokens)):
+        token = tokens[i]
+        top_tokens = logprobs["top_logprobs"][i]
+        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
+        if top_token != token:
+            is_greedy = False
+            break
+
+    return continuation_logprobs, is_greedy
+
+
 class GGMLLM(BaseLM):
     def __init__(self, base_url, truncate=False):
         super().__init__()
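For reference, here is a minimal sketch of what the new helper computes, using an invented OpenAI-style logprobs payload (the token strings, offsets, and scores below are made up for illustration):

logprobs = {
    "text_offset": [0, 4, 7, 10],                # character offset of each echoed token
    "tokens": ["The", " cat", " sat", "."],      # prompt tokens plus one generated token
    "token_logprobs": [None, -0.9, -0.4, -1.2],
    "top_logprobs": [None,
                     {" cat": -0.9, " dog": -1.1},
                     {" sat": -0.4, " ran": -2.0},
                     {".": -1.2, "!": -1.3}],
}

# Context is "The" (3 characters), so the continuation starts at " cat".
score, is_greedy = get_result(logprobs, context_length=3)
print(score, is_greedy)  # -1.3 (-0.9 + -0.4; the trailing generated token is excluded), True

The [idx:-1] slice is what keeps the single generated token out of the score, so only the echoed continuation contributes.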
@@ -17,6 +38,7 @@ class GGMLLM(BaseLM):
         self.truncate = truncate
         self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
         self.logprobs = 10
+        self.temperature = 0.0
         self.max_length = 1024
         self.vocab_size = self.tokenizer.vocab_size
@@ -24,9 +46,11 @@ class GGMLLM(BaseLM):
         for _ in range(retries):
             try:
                 prompt = context
+                request = {'prompt': prompt, 'logprobs': self.logprobs,
+                           'temperature': self.temperature}
                 if continuation:
                     prompt += continuation
-                request = {'prompt': prompt, 'logprobs': self.logprobs}
+                    request.update({'prompt': prompt, 'max_tokens': 1, 'echo': True})
                 if stop is not None:
                     request['stop'] = stop
                 response = requests.post(f"{self.base_url}/v1/completions", json=request)
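In other words, a scoring request now asks the server to echo the whole prompt back with per-token logprobs while generating only a single new token. A rough end-to-end sketch, assuming a local server URL and toy strings:

import requests

base_url = "http://localhost:8000"        # assumed; matches whatever GGMLLM was given
context, continuation = "The cat", " sat"

request = {
    "prompt": context + continuation,
    "logprobs": 10,       # top-10 candidates per position, as self.logprobs is set to
    "temperature": 0.0,   # deterministic decoding, so the greedy check is meaningful
    "max_tokens": 1,      # generate only one token...
    "echo": True,         # ...and echo the prompt back with its logprobs
}
response = requests.post(f"{base_url}/v1/completions", json=request).json()
score, is_greedy = get_result(response["choices"][0]["logprobs"], len(context))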
@@ -38,7 +62,6 @@ class GGMLLM(BaseLM):
             else:
                 raise Exception(f"Failed to get a valid response after {retries} retries.")
-
     def loglikelihood(self, requests):
         if not requests:
             return []
...@@ -49,8 +72,7 @@ class GGMLLM(BaseLM): ...@@ -49,8 +72,7 @@ class GGMLLM(BaseLM):
choice = response["choices"][0] choice = response["choices"][0]
logprobs = choice.get("logprobs") logprobs = choice.get("logprobs")
if logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]: if logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]:
logprob = logprobs["token_logprobs"][0] logprob, is_greedy = get_result(logprobs, len(context))
is_greedy = choice["finish_reason"] == "length"
res.append((logprob, is_greedy)) res.append((logprob, is_greedy))
else: else:
logger.warning("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.") logger.warning("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.")
@@ -58,7 +80,6 @@ class GGMLLM(BaseLM):
                 logger.error(f"Invalid response for loglikelihood. Response: {response}")
                 assert False
         return res
-
     def greedy_until(self, requests):
         if not requests:
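The practical effect is that loglikelihood now returns the summed log-probability of the whole continuation, rather than just the first continuation token's score. A hypothetical call (server URL and prompts invented):

lm = GGMLLM(base_url="http://localhost:8000")
pairs = [("Question: 2 + 2 = ", "4"),
         ("Question: 2 + 2 = ", "5")]
for (score, is_greedy) in lm.loglikelihood(pairs):
    print(score, is_greedy)   # the correct answer should score higher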
@@ -89,16 +110,15 @@ class GGMLLM(BaseLM):
         for request in requests:
             logprobs = []
             for i in range(0, len(request), self.max_length):
-                chunk = request[i:i+self.max_length]
-                chunk_loglikelihood = self.loglikelihood([(chunk, request[i+1:i+self.max_length+1])])
+                chunk = request[i:i + self.max_length]
+                chunk_loglikelihood = self.loglikelihood([(chunk, request[i + 1:i + self.max_length + 1])])
                 logprobs.extend(chunk_loglikelihood)
             avg_loglikelihood = sum([logprob for logprob, _ in logprobs]) / len(logprobs)
             results.append((avg_loglikelihood, True))
         return results
-
     def _model_call(self, inps):
         # Placeholder implementation
         raise NotImplementedError()
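For clarity, the rolling path above splits a long string into max_length-sized character windows and pairs each chunk with the same window shifted right by one character, mirroring next-token prediction at the character level. A toy illustration with max_length shrunk to 4:

request, max_length = "abcdefghij", 4
for i in range(0, len(request), max_length):
    chunk = request[i:i + max_length]
    target = request[i + 1:i + max_length + 1]   # same window, shifted by one
    print((chunk, target))
# ('abcd', 'bcde') ('efgh', 'fghi') ('ij', 'j')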
@@ -112,7 +132,7 @@ class GGMLLM(BaseLM):
     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens)

     @property
     def batch_size(self):
         # Placeholder implementation
@@ -128,10 +148,10 @@ class GGMLLM(BaseLM):
         # Placeholder implementation
         raise NotImplementedError()

     def max_length(self):
         return self.max_length

     @property
     def max_gen_toks(self):
         # Placeholder implementation
         raise NotImplementedError()
\ No newline at end of file