"docs/source/en/api/loaders.md" did not exist on "8bf80fc8d8aade3bd3fca5054d05b65488fbbf8f"
Unverified commit bc10a390, authored by Hailey Schoelkopf and committed by GitHub

Merge pull request #630 from EleutherAI/fix-stopseq

Add error handling for multi-token stopseq and `hf-causal` model type
parents 72b7f0c0 c9c141d2
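
For reference, the failure this patch guards against: `tok_encode` returns a list of token ids, and `(primary_until,) = self.tok_encode(until[0])` assumes that list has exactly one element. If the tokenizer splits the stop string into several tokens, the unpack raises ValueError. A minimal sketch of the same behaviour, assuming a GPT-2 tokenizer loaded via `transformers` (the stop strings below are illustrative, not taken from any task):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # A single-token stop sequence unpacks cleanly.
    (primary_until,) = tokenizer.encode("\n")   # "\n" is a single token for GPT-2

    # A multi-token stop sequence makes the same unpack fail.
    ids = tokenizer.encode("Question:")         # two tokens: "Question", ":"
    try:
        (primary_until,) = ids
    except ValueError:
        # This patch falls back to the EOS token and warns once.
        primary_until = tokenizer.eos_token_id
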
@@ -176,7 +176,9 @@ class BaseLM(LM):
     def _detect_batch_size(self, requests=None, pos=0):
         if requests:
             _, context_enc, continuation_enc = requests[pos]
-            max_length = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
+            max_length = len(
+                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
+            )
         else:
             max_length = self.max_length
@@ -212,7 +214,9 @@ class BaseLM(LM):
         for context, continuation in requests:
             if context == "":
                 # end of text as context
-                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(continuation)
+                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
+                    continuation
+                )
             else:
                 context_enc, continuation_enc = self._encode_pair(context, continuation)
@@ -290,15 +294,23 @@ class BaseLM(LM):
             sched = pos // int(n_reordered_requests / self.batch_schedule)
             if sched in self.batch_sizes:
                 return self.batch_sizes[sched]
-            print(f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size")
+            print(
+                f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
+            )
             self.batch_sizes[sched] = self._detect_batch_size(reordered_requests, pos)
             print(f"Determined largest batch size: {self.batch_sizes[sched]}")
             return self.batch_sizes[sched]

         for chunk in utils.chunks(
             tqdm(reordered_requests, disable=disable_tqdm),
-            n=self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0,
-            fn=_batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 else None,
+            n=self.batch_size
+            if self.batch_size != "auto"
+            else override_bs
+            if override_bs is not None
+            else 0,
+            fn=_batch_scheduler
+            if self.batch_size == "auto" and n_reordered_requests > 0
+            else None,
         ):
             inps = []
             cont_toks_list = []
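
The `_batch_scheduler` hunk above is only re-wrapped for line length, but the arithmetic it preserves is worth spelling out: with `batch_size = auto:N`, the reordered requests are cut into roughly N position buckets, and the largest workable batch size is re-detected once per bucket (after reordering, later buckets generally hold shorter sequences, so larger batches tend to fit). A standalone sketch of the bucket index, with hypothetical numbers:

    # Hypothetical: 1000 reordered requests and batch_size = "auto:4".
    n_reordered_requests = 1000
    batch_schedule = 4

    def bucket_for(pos):
        # Integer-divide the request position by the bucket width,
        # mirroring: sched = pos // int(n_reordered_requests / self.batch_schedule)
        return pos // int(n_reordered_requests / batch_schedule)

    print(bucket_for(0))    # 0 -> batch size detected for the first bucket
    print(bucket_for(249))  # 0 -> cached value from self.batch_sizes reused
    print(bucket_for(250))  # 1 -> detection runs again for the next bucket
    print(bucket_for(999))  # 3
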
@@ -411,13 +423,22 @@ class BaseLM(LM):
         re_ord = utils.Reorderer(requests, _collate)

+        warn_stop_seq = False
         for context, request_args in tqdm(re_ord.get_reordered()):
             until = request_args["until"]
             if isinstance(until, str):
                 until = [until]

             if until:
-                (primary_until,) = self.tok_encode(until[0])
+                try:
+                    (primary_until,) = self.tok_encode(until[0])
+                except ValueError:
+                    if not warn_stop_seq:
+                        print(
+                            "Warning: a primary stop sequence is multi-token! Will default to EOS token for this tokenizer. Consider using `hf-causal-experimental` for multi-token stop sequence support for the time being."
+                        )
+                        warn_stop_seq = True
+                    primary_until = self.eot_token_id
             else:
                 primary_until = None
...