Commit c9c141d2 authored by haileyschoelkopf

add err handling for multi-tok stopseq

parent 72b7f0c0
@@ -176,7 +176,9 @@ class BaseLM(LM):
     def _detect_batch_size(self, requests=None, pos=0):
         if requests:
             _, context_enc, continuation_enc = requests[pos]
-            max_length = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
+            max_length = len(
+                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
+            )
         else:
             max_length = self.max_length
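Note (not part of the commit): the reformatted expression computes the longest input the model will actually see, keeping at most `self.max_length + 1` trailing tokens of the concatenated sequence and then dropping the last token, which serves only as a prediction target. A toy illustration with hypothetical values:

max_length = 4  # hypothetical model context length
context_enc, continuation_enc = [1, 2, 3], [4, 5, 6]

# Keep at most max_length + 1 trailing tokens, then drop the final
# token (it is only a prediction target, never model input).
inp = (context_enc + continuation_enc)[-(max_length + 1):][:-1]
print(inp, len(inp))  # [2, 3, 4, 5] 4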
@@ -212,7 +214,9 @@ class BaseLM(LM):
         for context, continuation in requests:
             if context == "":
                 # end of text as context
-                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(continuation)
+                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
+                    continuation
+                )
             else:
                 context_enc, continuation_enc = self._encode_pair(context, continuation)
@@ -290,15 +294,23 @@ class BaseLM(LM):
             sched = pos // int(n_reordered_requests / self.batch_schedule)
             if sched in self.batch_sizes:
                 return self.batch_sizes[sched]
-            print(f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size")
+            print(
+                f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
+            )
             self.batch_sizes[sched] = self._detect_batch_size(reordered_requests, pos)
             print(f"Determined largest batch size: {self.batch_sizes[sched]}")
             return self.batch_sizes[sched]

         for chunk in utils.chunks(
             tqdm(reordered_requests, disable=disable_tqdm),
-            n=self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0,
-            fn=_batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 else None,
+            n=self.batch_size
+            if self.batch_size != "auto"
+            else override_bs
+            if override_bs is not None
+            else 0,
+            fn=_batch_scheduler
+            if self.batch_size == "auto" and n_reordered_requests > 0
+            else None,
         ):
             inps = []
             cont_toks_list = []
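Note (not part of the commit): with `batch_size = "auto"`, the `fn=_batch_scheduler` argument lets `utils.chunks` re-ask for a batch size as it advances through the reordered requests. A simplified sketch of that pattern (the real `utils.chunks` signature may differ):

def chunks(iterable, n=0, fn=None):
    # Yield lists of size n, or of whatever size fn(pos) returns
    # when a scheduler callback is supplied.
    batch = []
    for i, item in enumerate(iterable):
        batch.append(item)
        if len(batch) == (fn(i) if fn else n):
            yield batch
            batch = []
    if batch:
        yield batch

# Hypothetical scheduler standing in for _batch_scheduler: grow the
# batch size once enough requests have been processed.
def scheduler(pos):
    return 2 if pos < 4 else 4

print(list(chunks(range(10), fn=scheduler)))  # [[0, 1], [2, 3], [4, 5, 6, 7], [8, 9]]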
@@ -411,13 +423,22 @@ class BaseLM(LM):
         re_ord = utils.Reorderer(requests, _collate)

+        warn_stop_seq = False
         for context, request_args in tqdm(re_ord.get_reordered()):
             until = request_args["until"]
             if isinstance(until, str):
                 until = [until]

             if until:
-                (primary_until,) = self.tok_encode(until[0])
+                try:
+                    (primary_until,) = self.tok_encode(until[0])
+                except ValueError:
+                    if not warn_stop_seq:
+                        print(
+                            "Warning: a primary stop sequence is multi-token! Will default to EOS token for this tokenizer. Consider using `hf-causal-experimental` for multi-token stop sequence support for the time being."
+                        )
+                        warn_stop_seq = True
+                    primary_until = self.eot_token_id
             else:
                 primary_until = None
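Note (not part of the commit): the error handling works because single-element tuple unpacking raises `ValueError` whenever `tok_encode` returns anything other than exactly one token. A minimal sketch with a hypothetical tokenizer:

def tok_encode(s):
    # Hypothetical stand-in: one token per whitespace-separated word.
    return s.split()

eot_token_id = 0  # hypothetical EOS token id

for stop_seq in ["Question:", "the end"]:
    try:
        # Fails unless the stop sequence encodes to exactly one token.
        (primary_until,) = tok_encode(stop_seq)
    except ValueError:
        # Multi-token stop sequence: fall back to EOS, as the commit does.
        primary_until = eot_token_id
    print(repr(stop_seq), "->", primary_until)  # 'Question:' -> Question:, 'the end' -> 0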