"vscode:/vscode.git/clone" did not exist on "f586c415eb04c02ea6729f92df7b096688b468ca"
Commit 2bfc3ce4 authored by Baber
Browse files

If max_length is reached, truncate generations

parent 6f66224b
...@@ -283,6 +283,10 @@ class VLLM(TemplateLM): ...@@ -283,6 +283,10 @@ class VLLM(TemplateLM):
@property @property
def max_length(self): def max_length(self):
return 8096 if self._max_length > 8096 else self._max_length
@property
def _max_length(self):
if self._max_length: # if max length manually set, return it if self._max_length: # if max length manually set, return it
return self._max_length return self._max_length
if self.data_parallel_size <= 1: if self.data_parallel_size <= 1:
...@@ -627,13 +631,9 @@ class VLLM(TemplateLM): ...@@ -627,13 +631,9 @@ class VLLM(TemplateLM):
# set the max length in tokens of inputs ("context_enc") # set the max length in tokens of inputs ("context_enc")
# max len for inputs = max length, minus room to generate the max new tokens # max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks default_length = len(x) + max_gen_toks
if len(x) > max_ctx_len: if default_length > self.max_length:
eval_logger.warning( max_gen_toks = self.max_length - len(x)
f"Context length {len(x)} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
)
context_encoding_truncated.append(x[-max_ctx_len:])
else:
context_encoding_truncated.append(x) context_encoding_truncated.append(x)
# create sampling params # create sampling params
kwargs = self.modify_gen_kwargs(kwargs) kwargs = self.modify_gen_kwargs(kwargs)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment