Unverified Commit b31f92e8 authored by Baber Abbasi, committed by GitHub

API: fix maxlen; vllm: prefix_token_id bug (#2262)

* max_length - 1 (generation always >= 1)

* vllm: fix rolling prefix_token

* nit: add comment

* fixup! max_length should be handled for loglikelihoods
parent 8138fd52
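The first bullet captures the core off-by-one: a completion API always generates at least one token, so a prompt truncated to the full max_length overflows the context window by one. A minimal sketch of the failure mode, with illustrative numbers (not taken from the harness):

```python
# Illustrative numbers only; the window size is whatever the target model uses.
MAX_LENGTH = 4096      # model context window

prompt_tokens = 4096   # old behavior: prompt truncated to max_length
generated_tokens = 1   # completion endpoints always emit at least one token

total = prompt_tokens + generated_tokens
print(total > MAX_LENGTH)  # True: the request exceeds the window by exactly 1
# The fix truncates the prompt to MAX_LENGTH - 1, so prompt + 1 generated
# token fits exactly within the context window.
```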
```diff
@@ -104,6 +104,7 @@ class TemplateAPI(TemplateLM):
         self._truncate = truncate
         self._max_gen_toks = int(max_gen_toks)
         self._seed = int(seed)
+        eval_logger.info(f"Using max length {max_length}")
         self.max_length = max_length
         if int(num_concurrent) <= 1:
             eval_logger.info(
```
```diff
@@ -417,9 +418,10 @@ class TemplateAPI(TemplateLM):
         cache_keys = []
         for chunk in chunks:
             for cache_key, context_enc, continuation_enc in chunk:
-                inp = (context_enc + continuation_enc)[-(self.max_length) :]
+                # max_length - 1 as we always have 1 token for generation
+                inp = (context_enc + continuation_enc)[-(self.max_length - 1) :]
                 ctxlen = len(context_enc) - max(
-                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
+                    0, len(context_enc) + len(continuation_enc) - (self.max_length - 1)
                 )
                 inputs.append(inp)
```
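Worked through with small numbers (a hypothetical helper that mirrors the slicing and ctxlen arithmetic in the hunk above; it is not part of the harness API):

```python
def truncate_for_scoring(context_enc, continuation_enc, max_length):
    # Keep only the last (max_length - 1) tokens, reserving one position
    # for the token the API will generate.
    inp = (context_enc + continuation_enc)[-(max_length - 1):]
    # ctxlen = how many surviving tokens are pure context, so the caller
    # can slice continuation logprobs out of the response.
    ctxlen = len(context_enc) - max(
        0, len(context_enc) + len(continuation_enc) - (max_length - 1)
    )
    return inp, ctxlen

# 4 context + 3 continuation tokens against a 6-token window:
inp, ctxlen = truncate_for_scoring([1, 2, 3, 4], [5, 6, 7], max_length=6)
assert inp == [3, 4, 5, 6, 7]  # max_length - 1 == 5 tokens survive
assert ctxlen == 2             # [3, 4] are context; [5, 6, 7] are continuation
```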
```diff
@@ -619,7 +621,8 @@ class TemplateAPI(TemplateLM):
                 utils.get_rolling_token_windows(
                     token_list=self.tok_encode(string),
                     prefix_token=self.prefix_token_id,
-                    max_seq_len=self.max_length,
+                    # max_seq_len - (1 for context)
+                    max_seq_len=self.max_length - 1,
                     context_len=1,
                 ),
             )
```
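For the rolling-window change: each window feeds the model context_len conditioning tokens plus up to max_seq_len tokens to score, so max_seq_len must be max_length - 1 for the combined input to fit. A rough sketch of that accounting (an illustrative re-implementation, not lm_eval's actual get_rolling_token_windows, whose internals differ):

```python
def rolling_windows(tokens, prefix_token, max_seq_len, context_len=1):
    # Each yielded pair is (conditioning tokens, tokens to score); the model
    # input is their concatenation, at most context_len + max_seq_len long.
    pred_start, first = 0, True
    while pred_start < len(tokens):
        pred_end = min(pred_start + max_seq_len, len(tokens))
        if first:
            context, first = [prefix_token], False  # BOS-like token seeds window 1
        else:
            context = tokens[pred_start - context_len : pred_start]
        yield context, tokens[pred_start:pred_end]
        pred_start = pred_end

max_length = 4
for ctx, cont in rolling_windows(list(range(10)), prefix_token=-1,
                                 max_seq_len=max_length - 1):
    assert len(ctx) + len(cont) <= max_length  # input always fits the window
```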
```diff
@@ -289,7 +289,8 @@ class VLLM(TemplateLM):
                 make_disjoint_window,
                 get_rolling_token_windows(
                     token_list=self.tok_encode(string),
-                    prefix_token=self.eot_token_id,
+                    prefix_token=self.prefix_token_id,
+                    # max_seq_len - (1 for context)
                     max_seq_len=self.max_length - 1,
                     context_len=1,
                 ),
```
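The vllm hunk replaces a hard-coded eot_token_id with prefix_token_id. The distinction (paraphrased from the harness's usual behavior; the exact fallback order is an assumption here) is that prefix_token_id prefers a BOS token and only falls back to EOT when none exists:

```python
class FakeLM:
    # Hypothetical stand-in, not the harness's class hierarchy.
    bos_token_id = 1              # e.g. <s>
    eot_token_id = 2              # e.g. <|endoftext|>
    custom_prefix_token_id = None

    @property
    def prefix_token_id(self):
        # Assumed fallback order: explicit override, then BOS, then EOT.
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.bos_token_id is not None:
            return self.bos_token_id
        return self.eot_token_id

lm = FakeLM()
# The old code conditioned the first rolling window on EOT (2); after the
# fix it is conditioned on prefix_token_id, which resolves to BOS (1) here.
print(lm.eot_token_id, lm.prefix_token_id)  # -> 2 1
```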