"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "6a5b352aaf1e53c490945bd87ebe6ab456b5eda6"
Unverified commit 47a551d1, authored by Patrick von Platen, committed by GitHub
Browse files

[pipeline] Tokenizer should not add special tokens for text generation (#4686)

* allow to not add special tokens

* remove print
parent f6d5046a
...@@ -454,14 +454,17 @@ class Pipeline(_ScikitCompat): ...@@ -454,14 +454,17 @@ class Pipeline(_ScikitCompat):
""" """
return {name: tensor.to(self.device) for name, tensor in inputs.items()} return {name: tensor.to(self.device) for name, tensor in inputs.items()}
def _parse_and_tokenize(self, *args, pad_to_max_length=True, **kwargs): def _parse_and_tokenize(self, *args, pad_to_max_length=True, add_special_tokens=True, **kwargs):
""" """
Parse arguments and tokenize Parse arguments and tokenize
""" """
# Parse arguments # Parse arguments
inputs = self._args_parser(*args, **kwargs) inputs = self._args_parser(*args, **kwargs)
inputs = self.tokenizer.batch_encode_plus( inputs = self.tokenizer.batch_encode_plus(
inputs, add_special_tokens=True, return_tensors=self.framework, pad_to_max_length=pad_to_max_length, inputs,
add_special_tokens=add_special_tokens,
return_tensors=self.framework,
pad_to_max_length=pad_to_max_length,
) )
return inputs return inputs
...@@ -617,9 +620,11 @@ class TextGenerationPipeline(Pipeline): ...@@ -617,9 +620,11 @@ class TextGenerationPipeline(Pipeline):
# Manage correct placement of the tensors # Manage correct placement of the tensors
with self.device_placement(): with self.device_placement():
if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]: if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text, pad_to_max_length=False) inputs = self._parse_and_tokenize(
self.PADDING_TEXT + prompt_text, pad_to_max_length=False, add_special_tokens=False
)
else: else:
inputs = self._parse_and_tokenize(prompt_text, pad_to_max_length=False) inputs = self._parse_and_tokenize(prompt_text, pad_to_max_length=False, add_special_tokens=False)
# set input_ids to None to allow empty prompt # set input_ids to None to allow empty prompt
if inputs["input_ids"].shape[-1] == 0: if inputs["input_ids"].shape[-1] == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment