"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "c94a436dbb080741d4f0f4e7bbc79a31b76dac32"
Commit cf4cd770 authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed eos decoding

parent 02e841ce
...@@ -707,15 +707,17 @@ class HFLM(TemplateLM): ...@@ -707,15 +707,17 @@ class HFLM(TemplateLM):
encoding["attention_mask"] = encoding["attention_mask"][ encoding["attention_mask"] = encoding["attention_mask"][
:, -left_truncate_len: :, -left_truncate_len:
] ]
# print(encoding["input_ids"][0])
# import sys; sys.exit()
self.tokenizer.padding_side = old_padding_side self.tokenizer.padding_side = old_padding_side
return encoding["input_ids"], encoding["attention_mask"] return encoding["input_ids"], encoding["attention_mask"]
def tok_decode(self, tokens): def tok_decode(self, tokens, skip_special_tokens=True):
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
return self.tokenizer.decode(tokens) return self.tokenizer.decode(tokens)
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
return self.tokenizer.decode(tokens, skip_special_tokens=True) return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def _model_call(self, inps, attn_mask=None, labels=None): def _model_call(self, inps, attn_mask=None, labels=None):
""" """
...@@ -1158,7 +1160,7 @@ class HFLM(TemplateLM): ...@@ -1158,7 +1160,7 @@ class HFLM(TemplateLM):
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
) )
# add EOS token to stop sequences # add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id) eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
if not until: if not until:
until = [eos] until = [eos]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment