"...resnet50_tensorflow.git" did not exist on "14f78b2cbd9424638d76e1c8ae09390e2213e637"
Unverified Commit 8af1970e authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Fix marian slow test (#6854)

parent bbdba0a7
......@@ -38,6 +38,7 @@ if is_torch_available():
convert_hf_name_to_opus_name,
convert_opus_name_to_hf_name,
)
from transformers.modeling_bart import shift_tokens_right
from transformers.pipelines import TranslationPipeline
......@@ -116,18 +117,21 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
expected_ids = [38, 121, 14, 697, 38848, 0]
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
desired_keys = {
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
"labels",
}
self.assertSetEqual(desired_keys, set(model_inputs.keys()))
model_inputs["decoder_input_ids"] = shift_tokens_right(model_inputs.labels, self.tokenizer.pad_token_id)
model_inputs["return_dict"] = True
model_inputs["use_cache"] = False
with torch.no_grad():
logits, *enc_features = self.model(**model_inputs)
max_indices = logits.argmax(-1)
outputs = self.model(**model_inputs)
max_indices = outputs.logits.argmax(-1)
self.tokenizer.batch_decode(max_indices)
def test_unk_support(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment