Unverified Commit 284dd80d authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

always include EOS token in stopsequences if possible (#1480)

parent d272c19f
......@@ -1151,8 +1151,12 @@ class HFLM(TemplateLM):
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id)
if not until:
until = [self.tok_decode(self.eot_token_id)]
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
......
......@@ -666,8 +666,12 @@ class NEURON_HF(TemplateLM):
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id)
if not until:
until = [self.tok_decode(self.eot_token_id)]
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
......
......@@ -282,8 +282,12 @@ class VLLM(TemplateLM):
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
)
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id)
if not until:
until = [self.tokenizer.decode(self.eot_token_id)]
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment