Unverified commit 2e60276b, authored by Suraj Patil, committed by GitHub
Browse files

[M2M100Tokenizer] fix _build_translation_inputs (#14382)



* add return_tensors parameter

* fix test

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 31659304
...@@ -332,7 +332,7 @@ class M2M100Tokenizer(PreTrainedTokenizer): ...@@ -332,7 +332,7 @@ class M2M100Tokenizer(PreTrainedTokenizer):
if src_lang is None or tgt_lang is None: if src_lang is None or tgt_lang is None:
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
self.src_lang = src_lang self.src_lang = src_lang
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) inputs = self(raw_inputs, add_special_tokens=True, **extra_kwargs)
tgt_lang_id = self.get_lang_id(tgt_lang) tgt_lang_id = self.get_lang_id(tgt_lang)
inputs["forced_bos_token_id"] = tgt_lang_id inputs["forced_bos_token_id"] = tgt_lang_id
return inputs return inputs
......
...@@ -226,7 +226,7 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase): ...@@ -226,7 +226,7 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
@require_torch @require_torch
def test_tokenizer_translation(self): def test_tokenizer_translation(self):
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar") inputs = self.tokenizer._build_translation_inputs("A test", return_tensors="pt", src_lang="en", tgt_lang="ar")
self.assertEqual( self.assertEqual(
nested_simplify(inputs), nested_simplify(inputs),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment