Unverified Commit d3cfdcf6 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #5 from EleutherAI/seq2seq-support

fix edge cases for seq2seq
parents 89ad0186 0500fb33
...@@ -127,7 +127,7 @@ class HFLM(LM): ...@@ -127,7 +127,7 @@ class HFLM(LM):
) )
self.model.to(self.device) self.model.to(self.device)
else: else:
self.model = accelerator.prepare(self.model) self._model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}") self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.accelerator = accelerator self.accelerator = accelerator
...@@ -373,7 +373,7 @@ class HFLM(LM): ...@@ -373,7 +373,7 @@ class HFLM(LM):
assert len(continuation_enc) > 0 assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length assert len(continuation_enc) <= self.max_length
# how this all works: # how this all works (illustrated on a causal decoder-only setup):
# CTX CONT # CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# model \ \ # model \ \
...@@ -519,7 +519,6 @@ class HFLM(LM): ...@@ -519,7 +519,6 @@ class HFLM(LM):
max_gen_toks = gen_kwargs.pop("max_gen_toks") max_gen_toks = gen_kwargs.pop("max_gen_toks")
else: else:
max_gen_toks = self.max_gen_toks max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering # first stop sequence is used to halt generation upon encountering
(primary_until) = until[0] (primary_until) = until[0]
...@@ -552,7 +551,8 @@ class HFLM(LM): ...@@ -552,7 +551,8 @@ class HFLM(LM):
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until: for term in until:
s = s.split(term)[0] if len(term) > 0: # ignore '' separator, for seq2seq case where
s = s.split(term)[0]
res.append(s) res.append(s)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment