Unverified Commit 2f03271d authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub

add warning on truncation (#2962)

parent a96085f1
@@ -430,6 +430,12 @@ class VLLM(TemplateLM):
# set the max length in tokens of inputs ("context_enc")
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
all_lengths = [len(x) for x in context_encoding]
for length in all_lengths:
    if length > max_ctx_len:
        eval_logger.warning(
            f"Context length {length} exceeds max context length {max_ctx_len} "
            f"(max_length {self.max_length} minus max_gen_toks {max_gen_toks}). Truncating context."
        )
# left-truncate: keep only the most recent max_ctx_len tokens of each context
context_encoding = [x[-max_ctx_len:] for x in context_encoding]
# perform batched generation
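For context, here is a minimal standalone sketch of the left-truncation pattern this hunk adds a warning for. The concrete values (`max_length`, `max_gen_toks`, the sample token lists) are illustrative, not taken from the harness:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("truncation-demo")

# Illustrative limits; in the harness these come from the model config.
max_length = 10    # total sequence budget (prompt + generation)
max_gen_toks = 4   # tokens reserved for generation
max_ctx_len = max_length - max_gen_toks  # room left for the prompt: 6

# One prompt that fits, one that is too long.
context_encoding = [list(range(3)), list(range(9))]

for length in (len(x) for x in context_encoding):
    if length > max_ctx_len:
        logger.warning(
            "Context length %d exceeds max context length %d. Truncating context.",
            length, max_ctx_len,
        )

# Left-truncation keeps the most recent tokens and drops the oldest ones.
context_encoding = [x[-max_ctx_len:] for x in context_encoding]
assert all(len(x) <= max_ctx_len for x in context_encoding)
```

Warning before truncating (rather than truncating silently) matters here because a clipped prompt can silently degrade evaluation scores.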
@@ -480,6 +486,10 @@ class VLLM(TemplateLM):
inputs = []
ctxlens = []
for cache_key, context_enc, continuation_enc in chunk:
    if (full_length := len(context_enc) + len(continuation_enc)) >= self.max_length:
        eval_logger.warning(
            f"Context length {full_length} exceeds max length ({self.max_length}). Truncating context."
        )
    # keep the last max_length tokens of the combined sequence
    inp = (context_enc + continuation_enc)[-(self.max_length) :]
    # ctxlen shrinks by however many tokens were cut from the left
    ctxlen = len(context_enc) - max(
        0, len(context_enc) + len(continuation_enc) - (self.max_length)
    )
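A worked example of the `ctxlen` bookkeeping in this hunk, as a standalone sketch (the token values and `max_length` are made up for illustration): with 7 context tokens, 5 continuation tokens, and a budget of 10, two tokens are cut from the left, so the effective context length drops from 7 to 5.

```python
# Standalone illustration of the truncation + ctxlen arithmetic above.
max_length = 10                        # illustrative budget
context_enc = list(range(7))           # 7 context tokens
continuation_enc = list(range(7, 12))  # 5 continuation tokens

full_length = len(context_enc) + len(continuation_enc)  # 12 > 10, so truncate

# Keep the last max_length tokens; the 2 oldest context tokens are dropped.
inp = (context_enc + continuation_enc)[-max_length:]

# ctxlen shrinks by exactly the number of tokens cut from the left.
ctxlen = len(context_enc) - max(0, full_length - max_length)

assert len(inp) == max_length
assert ctxlen == 5  # 7 original context tokens minus the 2 truncated ones
assert inp[:ctxlen] == context_enc[-ctxlen:]  # surviving context prefix of inp
```

Adjusting `ctxlen` this way keeps the context/continuation boundary correct after truncation, so log-likelihood is still scored over the continuation tokens only.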