"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2ae678229fbacc588e3c16f9183d2c88add83521"
Unverified Commit 7e442874 authored by Lysandre Debut, committed by GitHub

Conversion to tensors requires padding (#10661)

parent 2adc8c92
@@ -354,7 +354,9 @@ class MarianIntegrationTest(unittest.TestCase):
         self.assertListEqual(self.expected_text, generated_words)

     def translate_src_text(self, **tokenizer_kwargs):
-        model_inputs = self.tokenizer(self.src_text, return_tensors="pt", **tokenizer_kwargs).to(torch_device)
+        model_inputs = self.tokenizer(self.src_text, padding=True, return_tensors="pt", **tokenizer_kwargs).to(
+            torch_device
+        )
         self.assertEqual(self.model.device, model_inputs.input_ids.device)
         generated_ids = self.model.generate(
             model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
...
@@ -363,7 +363,7 @@ class AbstractMarianIntegrationTest(unittest.TestCase):
         self.assertListEqual(self.expected_text, generated_words)

     def translate_src_text(self, **tokenizer_kwargs):
-        model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
+        model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, padding=True, return_tensors="tf")
         generated_ids = self.model.generate(
             model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
         )
...
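For context, a tokenizer can only pack a batch of sentences into a single rectangular tensor when every sequence has the same length, so calling it with return_tensors="pt" or return_tensors="tf" on a multi-sentence src_text requires padding=True. A minimal sketch of that behavior, not taken from this commit; the checkpoint name is only an illustrative Marian model:

from transformers import AutoTokenizer

# Illustrative Marian checkpoint; any translation tokenizer behaves the same way here.
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

src_text = ["I am a small frog.", "Tom asked his teacher for advice."]

# Without padding, sequences of different lengths cannot be stacked into one tensor,
# so the call below raises a ValueError suggesting padding/truncation be enabled.
try:
    tokenizer(src_text, return_tensors="pt")
except ValueError as err:
    print("failed without padding:", err)

# With padding=True the shorter sequence is padded to the longest one in the batch.
model_inputs = tokenizer(src_text, padding=True, return_tensors="pt")
print(model_inputs.input_ids.shape)  # torch.Size([2, max_sequence_length])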