Unverified Commit d3cfdcf6 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #5 from EleutherAI/seq2seq-support

fix edge cases for seq2seq
parents 89ad0186 0500fb33
......@@ -127,7 +127,7 @@ class HFLM(LM):
)
self.model.to(self.device)
else:
self.model = accelerator.prepare(self.model)
self._model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.accelerator = accelerator
......@@ -373,7 +373,7 @@ class HFLM(LM):
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length
# how this all works:
# how this all works (illustrated on a causal decoder-only setup):
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# model \ \
......@@ -519,7 +519,6 @@ class HFLM(LM):
max_gen_toks = gen_kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
(primary_until) = until[0]
......@@ -552,7 +551,8 @@ class HFLM(LM):
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
s = s.split(term)[0]
if len(term) > 0: # ignore '' separator, for seq2seq case where
s = s.split(term)[0]
res.append(s)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment