Commit 8ad386eb authored by baberabb's avatar baberabb
Browse files

added logliklihood_rolling and fixed greedy_until

parent 331340ad
...@@ -134,7 +134,7 @@ class OpenaiCompletionsLM(LM): ...@@ -134,7 +134,7 @@ class OpenaiCompletionsLM(LM):
continuation_enc = whole_enc[context_enc_len:] continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc return context_enc, continuation_enc
def loglikelihood(self, requests) -> List[List[float]]: def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
new_reqs = [] new_reqs = []
for context, continuation in [req.args for req in requests]: for context, continuation in [req.args for req in requests]:
if context == "": if context == "":
...@@ -149,13 +149,15 @@ class OpenaiCompletionsLM(LM): ...@@ -149,13 +149,15 @@ class OpenaiCompletionsLM(LM):
return self._loglikelihood_tokens(new_reqs) return self._loglikelihood_tokens(new_reqs)
def _loglikelihood_tokens(self, requests, disable_tqdm=False) -> List[List[float]]: def _loglikelihood_tokens(
self, requests, disable_tqdm=False
) -> List[Tuple[float, bool]]:
res = [] res = []
def _collate(x): def _collate(x):
# this doesn't efficiently handle last-token differences yet, but those are kinda annoying because # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
# it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
# we care about and so we need some kind of backup for when it isn't # we care about, and so we need some kind of backup for when it isn't
toks = x[1] + x[2] toks = x[1] + x[2]
return -len(toks), tuple(toks) return -len(toks), tuple(toks)
...@@ -197,13 +199,13 @@ class OpenaiCompletionsLM(LM): ...@@ -197,13 +199,13 @@ class OpenaiCompletionsLM(LM):
# partial caching # partial caching
if cache_key is not None: if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer) self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res) return re_ord.get_original(res)
def greedy_until(self, requests) -> List[str]: def greedy_until(self, requests) -> List[str]:
if not requests: if not requests:
return [] return []
res = [] res = []
requests = [req.args for req in requests]
def _collate(x): def _collate(x):
toks = self.tok_encode(x[0]) toks = self.tok_encode(x[0])
...@@ -253,7 +255,7 @@ class OpenaiCompletionsLM(LM): ...@@ -253,7 +255,7 @@ class OpenaiCompletionsLM(LM):
for resp, (context, args_) in zip(response.choices, chunk): for resp, (context, args_) in zip(response.choices, chunk):
s = resp["text"] s = resp["text"]
until_ = args_.get(["until"], []) until_ = args_.get("until", [])
for term in until_: for term in until_:
if len(term) > 0: if len(term) > 0:
...@@ -265,7 +267,6 @@ class OpenaiCompletionsLM(LM): ...@@ -265,7 +267,6 @@ class OpenaiCompletionsLM(LM):
) )
res.append(s) res.append(s)
return re_ord.get_original(res) return re_ord.get_original(res)
def _model_call(self, inps): def _model_call(self, inps):
...@@ -276,6 +277,33 @@ class OpenaiCompletionsLM(LM): ...@@ -276,6 +277,33 @@ class OpenaiCompletionsLM(LM):
# Isn't used because we override greedy_until # Isn't used because we override greedy_until
raise NotImplementedError() raise NotImplementedError()
def loglikelihood_rolling(self, requests): def loglikelihood_rolling(self, requests) -> List[float]:
# Isn't used because we override _loglikelihood_tokens loglikelihoods = []
raise NotImplementedError()
for (string,) in tqdm([req.args for req in requests]):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
string_nll = self._loglikelihood_tokens(
rolling_token_windows,
disable_tqdm=True,
)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment