Commit 26a9a445 authored by Benjamin Fattori's avatar Benjamin Fattori Committed by lintangsutawika
Browse files

multi-device support for loglikelihood_rolling

parent 3589abbb
...@@ -7,7 +7,8 @@ import torch.nn.functional as F ...@@ -7,7 +7,8 @@ import torch.nn.functional as F
from lm_eval import utils from lm_eval import utils
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
from lm_eval.api.model import LM, register_model from lm_eval.api.registry import register_model
from lm_eval.api.model import LM
from accelerate import Accelerator from accelerate import Accelerator
from typing import List from typing import List
...@@ -172,28 +173,42 @@ class Seq2SeqHFLM(LM): ...@@ -172,28 +173,42 @@ class Seq2SeqHFLM(LM):
def loglikelihood_rolling(self, requests): def loglikelihood_rolling(self, requests):
loglikelihoods = [] loglikelihoods = []
for (string,) in tqdm(requests): for (string,) in tqdm([req.args for req in requests]):
rolling_token_windows = list( rolling_token_windows = list(
map( map(
utils.make_disjoint_window, utils.make_disjoint_window,
utils.get_rolling_token_windows( utils.get_rolling_token_windows(
token_list=self.tok_encode(string), token_list=self.tok_encode(string),
prefix_token=None, prefix_token=self.eot_token_id,
max_seq_len=self.max_length, max_seq_len=self.max_length,
context_len=1, context_len=1,
), ),
) )
) )
rolling_token_windows = [(self.eot_token_id,) + x for x in rolling_token_windows] rolling_token_windows = [(None,) + x for x in rolling_token_windows]
pad_amnt = 0
if self.world_size > 1:
# TODO: Comment on what we do here
mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
gathered = (
self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
)
pad_amnt = max(gathered) - gathered[self.rank]
if pad_amnt > 0:
rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
string_nll = self._loglikelihood_tokens( string_nll = self._loglikelihood_tokens(
rolling_token_windows, rolling_token_windows, disable_tqdm=True
disable_tqdm=True,
) )
# discard is_greedy if (self.world_size > 1) and (pad_amnt > 0):
string_nll = [x[0] for x in string_nll] string_nll = [x[0] for x in string_nll[:-pad_amnt]]
else:
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll) string_nll = sum(string_nll)
loglikelihoods.append(string_nll) loglikelihoods.append(string_nll)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment