Unverified Commit 7aaceeec authored by Ameya Godbole, committed by GitHub

Fix error when collating queries with different continuation lengths (fixes #2984) (#2987)



* FIX error due to grouping queries with different continuation lengths

Make the Collator choose the query with the longest continuation as the
candidate for generation

* use max for key selection

* added comments explaining variable cont length (identical ctx+cont[:-1])

---------
Co-authored-by: Baber <baber@hey.com>
parent 357d4eaa
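
To make the failure mode concrete, here is a minimal sketch with made-up token ids (not harness code) of how two requests can land in the same context group while their continuations have different lengths, because the context/continuation split differs:

req_a = ([1, 2, 3], [4])   # (context_enc, continuation_enc): continuation length 1
req_b = ([1, 2], [3, 4])   # continuation length 2

# Both requests reduce to the same grouping key ctx + cont[:-1],
# so group_by="contexts" runs a single forward pass for the group.
key_a = tuple(req_a[0] + req_a[1][:-1])
key_b = tuple(req_b[0] + req_b[1][:-1])
assert key_a == key_b == (1, 2, 3)

Before this fix, the group's first member served as the representative for the forward pass, so a shorter continuation (req_a here) determined how many logit positions were kept and the longer member (req_b) could not be scored.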
@@ -1136,7 +1136,7 @@ class HFLM(TemplateLM):
             if self.backend == "causal":
                 total_length = len(context_enc) + len(continuation_enc)
                 if total_length > self.max_length + 1:
-                    eval_logger.warn(
+                    eval_logger.warning(
                         f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
                         f"exceeds model's maximum length ({self.max_length}). "
                         f"Truncating {total_length - self.max_length + 1} tokens from the left."
@@ -1247,7 +1247,12 @@ class HFLM(TemplateLM):
                     cont_toks = torch.tensor(
                         cont_toks, dtype=torch.long, device=self.device
                     ).unsqueeze(0)  # [1, seq]
-                    max_equal = (greedy_tokens == cont_toks).all()
+                    # Use trailing slice [-cont_toks.shape[1]:] to handle variable cont_len (identical ctx+cont[:-1]),
+                    # i.e. continuations may be sliced at different points. The Collator ensures we have sufficient
+                    # greedy_tokens by choosing the key with the longest continuation when group_by="contexts".
+                    max_equal = (
+                        greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks
+                    ).all()
 
                     # Obtain log-probs at the corresponding continuation token indices
                     # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
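
A toy example (assumed token ids, not harness code) of the trailing-slice comparison: the group's forward pass produced greedy tokens for the longest continuation, so a shorter member is checked only against the trailing slice of matching length.

import torch

greedy_tokens = torch.tensor([[7, 8, 9]])  # [1, max_cont_len] from the group's representative
cont_toks = torch.tensor([[8, 9]])         # [1, cont_len] for a shorter group member

# Compare only the last cont_len positions, mirroring the diff above.
max_equal = (greedy_tokens[:, -cont_toks.shape[1]:] == cont_toks).all()
print(max_equal)  # tensor(True)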
@@ -428,9 +428,13 @@ class Collator:
             batch = self.get_chunks(values, n=n, fn=batch_fn)
             yield from batch
         elif self._group_by == "contexts":
-            # Get one sample from each key
+            # Get one sample from each key.
+            # Select the longest continuation per group to ensure sufficient context logits.
             values = self._reorder(
-                [value[0] for value in self._arr_with_indices.values()]
+                [
+                    max(value, key=lambda x: len(x[1][-1]))
+                    for value in self._arr_with_indices.values()
+                ]
             )
             batch = self.get_chunks(values, n=n, fn=batch_fn)
             yield from batch
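
A sketch of the new selection rule, assuming (as the lambda implies) that each group value is a list of (original_index, (cache_key, context_enc, continuation_enc)) entries: the key len(x[1][-1]) is an entry's continuation length, so max picks the member whose continuation is longest.

grouped = {
    (1, 2, 3): [  # hypothetical group key: ctx + cont[:-1]
        (0, ("key_a", [1, 2, 3], [4])),  # continuation length 1
        (1, ("key_b", [1, 2], [3, 4])),  # continuation length 2 -> chosen
    ]
}
representatives = [
    max(value, key=lambda x: len(x[1][-1]))
    for value in grouped.values()
]
print(representatives)  # [(1, ('key_b', [1, 2], [3, 4]))]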