Commit 71388a7e authored by Benjamin Fattori

pad out all batches in loglikelihood_rolling; confirmed correctness

parent c10c08a2
@@ -172,82 +172,47 @@ class HFLM(LM):
         # TODO: Implement caching once we've confirmed the perplexity implementation
         # TODO: automatic batch size detection for vectorization
-        extra_pad = []
-        numpad_batches = 0
-        if self.world_size > 1:
-            cumulative_batches = 0  # balance token batches among iterators
-            # compute cumlative batches seen per host
-            for (string,) in tqdm([req.args for req in requests], disable=True):
-                rolling_token_windows = list(
-                    map(
-                        utils.make_disjoint_window,
-                        utils.get_rolling_token_windows(
-                            token_list=self.tok_encode(string),
-                            prefix_token=self.eot_token_id,
-                            max_seq_len=self.max_length,
-                            context_len=1,
-                        ),
-                    )
-                )
-                rolling_token_windows = [(None,) + x for x in rolling_token_windows]
-                cumulative_batches += len(rolling_token_windows)
-            cumul_batches_ranks = torch.tensor(cumulative_batches, device=self.device)
-            gathered_item = self.accelerator.gather(cumul_batches_ranks).cpu().detach().numpy().tolist()
-            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
-            numpad_batches = max(gathered_item) - gathered_item[self.rank]
-            # pad iterators with a pseudodocument
-            extra_pad = [('pad',)] if max(gathered_item) - min(gathered_item) > 0 else []
         loglikelihoods = []
-        for (string,) in tqdm(extra_pad + [req.args for req in requests], disable=(self.rank != 0)):
-            if numpad_batches > 0:
-                rolling_token_windows = list(
-                    map(
-                        utils.make_disjoint_window,
-                        utils.get_rolling_token_windows(
-                            token_list=[self.eot_token_id]*self.max_length*numpad_batches,
-                            prefix_token=self.eot_token_id,
-                            max_seq_len=self.max_length,
-                            context_len=1,
-                        ),
-                    )
-                )
-            else:
-                rolling_token_windows = list(
-                    map(
-                        utils.make_disjoint_window,
-                        utils.get_rolling_token_windows(
-                            token_list=self.tok_encode(string),
-                            prefix_token=self.eot_token_id,
-                            max_seq_len=self.max_length,
-                            context_len=1,
-                        ),
-                    )
-                )
+        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
+            rolling_token_windows = list(
+                map(
+                    utils.make_disjoint_window,
+                    utils.get_rolling_token_windows(
+                        token_list=self.tok_encode(string),
+                        prefix_token=self.eot_token_id,
+                        max_seq_len=self.max_length,
+                        context_len=1,
+                    ),
+                )
+            )
             rolling_token_windows = [(None,) + x for x in rolling_token_windows]
             # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
             # that
+            pad_amnt = 0
+            if self.world_size > 1:
+                # TODO: Comment on what we do here
+                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
+                gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
+                pad_amnt = max(gathered) - gathered[self.rank]
+                if pad_amnt > 0:
+                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
             string_nll = self._loglikelihood_tokens(
                 rolling_token_windows, disable_tqdm=True
             )
-            if (numpad_batches > 0) or (string == 'pad'):
-                numpad_batches = 0
+            if (self.world_size > 1) and (pad_amnt > 0):
+                string_nll = [x[0] for x in string_nll[:-pad_amnt]]
             else:
                 # discard is_greedy
                 string_nll = [x[0] for x in string_nll]
             string_nll = sum(string_nll)
             loglikelihoods.append(string_nll)
         return loglikelihoods
@@ -271,7 +236,8 @@ class HFLM(LM):
         re_ord = utils.Reorderer(requests, _collate)
         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))), self.batch_size
         ):
             inps = []
             cont_toks_list = []
             inplens = []
...
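
For context, the added code relies on a gather-and-pad trick: each rank reports how many rolling windows its current document produced, pads its own window list up to the per-rank maximum so all ranks issue the same number of forward passes (keeping DDP/FSDP collectives in sync), and then discards the scores that came from padding. Below is a minimal, self-contained sketch of that scheme using Hugging Face Accelerate; it is not part of the commit, and the names pad_and_score and score_windows are illustrative stand-ins rather than harness APIs.

# Minimal sketch of the gather-and-pad scheme, assuming Hugging Face Accelerate.
# pad_and_score and score_windows are hypothetical helpers, not harness functions.
import torch
from accelerate import Accelerator

def pad_and_score(rolling_token_windows, accelerator: Accelerator, score_windows):
    pad_amnt = 0
    if accelerator.num_processes > 1:
        # Gather how many windows each rank holds for its current document.
        mine = torch.tensor(len(rolling_token_windows), device=accelerator.device)
        counts = accelerator.gather(mine).cpu().tolist()
        # Pad this rank up to the largest count so every rank runs the same
        # number of forward passes and the collectives stay aligned.
        pad_amnt = max(counts) - counts[accelerator.process_index]
        if pad_amnt > 0:
            rolling_token_windows = rolling_token_windows + pad_amnt * [rolling_token_windows[0]]

    nlls = score_windows(rolling_token_windows)  # one log-likelihood per window
    # Drop the scores produced by padding before summing the document NLL.
    if pad_amnt > 0:
        nlls = nlls[:-pad_amnt]
    return sum(nlls)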