Unverified Commit 69ed3606 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

fix BlenderbotSmallTokenizer (#9538)

* add model_input_names

* fix test
parent 2df34f4a
...@@ -92,6 +92,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer): ...@@ -92,6 +92,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
}, },
} }
max_model_input_sizes = {"facebook/blenderbot_small-90M": 512} max_model_input_sizes = {"facebook/blenderbot_small-90M": 512}
model_input_names = ["attention_mask"]
def __init__( def __init__(
self, self,
......
...@@ -288,8 +288,6 @@ class Blenderbot90MIntegrationTests(unittest.TestCase): ...@@ -288,8 +288,6 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device) model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
# model does not have "token_type_ids"
model_inputs.pop("token_type_ids")
assert isinstance(self.tokenizer, BlenderbotSmallTokenizer) assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
generated_ids = self.model.generate(**model_inputs)[0] generated_ids = self.model.generate(**model_inputs)[0]
reply = self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) reply = self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
...@@ -302,8 +300,6 @@ class Blenderbot90MIntegrationTests(unittest.TestCase): ...@@ -302,8 +300,6 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
def test_90_generation_from_short_input(self): def test_90_generation_from_short_input(self):
model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device) model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device)
# model does not have "token_type_ids"
model_inputs.pop("token_type_ids")
generated_utterances = self.model.generate(**model_inputs) generated_utterances = self.model.generate(**model_inputs)
clean_txt = self.tokenizer.decode( clean_txt = self.tokenizer.decode(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment