Commit 3589abbb authored by Benjamin Fattori, committed by lintangsutawika
Browse files

loglikelihood_rolling for seq2seq

parent 3789d340
......@@ -171,7 +171,34 @@ class Seq2SeqHFLM(LM):
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests):
    """Score each request string by summing scores over rolling token windows.

    Each request is unpacked as a 1-tuple ``(string,)``. The string is
    tokenized with ``self.tok_encode``, split into windows by
    ``utils.get_rolling_token_windows`` (bounded by ``self.max_length``),
    and every window is passed through ``utils.make_disjoint_window``
    before being scored by ``self._loglikelihood_tokens``.

    Args:
        requests: Iterable of 1-tuples, each holding the string to score.

    Returns:
        A list with one entry per request: the sum of the first element
        of every result returned by ``_loglikelihood_tokens`` for that
        request's windows.
    """
    loglikelihoods = []
    for (string,) in tqdm(requests):
        # NOTE(review): prefix_token is None here while eot_token_id is
        # prepended to each window tuple below. Confirm this matches the
        # request layout that _loglikelihood_tokens expects.
        rolling_token_windows = list(
            map(
                utils.make_disjoint_window,
                utils.get_rolling_token_windows(
                    token_list=self.tok_encode(string),
                    prefix_token=None,
                    max_seq_len=self.max_length,
                    context_len=1,
                ),
            )
        )
        rolling_token_windows = [
            (self.eot_token_id,) + x for x in rolling_token_windows
        ]

        string_nll = self._loglikelihood_tokens(
            rolling_token_windows,
            disable_tqdm=True,
        )

        # Keep only the first element of each result (discard is_greedy)
        # and accumulate across all windows of this string.
        loglikelihoods.append(sum(x[0] for x in string_nll))

    return loglikelihoods
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
res = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment