Unverified Commit 9879a1d5 authored by NielsRogge, committed by GitHub

Fix LayoutLMv2 test (#15939)

* Fix LayoutLMv2 test

* Update black
parent 8b9ae455
@@ -31,14 +31,7 @@ from transformers.models.layoutlmv2.tokenization_layoutlmv2 import (
     _is_punctuation,
     _is_whitespace,
 )
-from transformers.testing_utils import (
-    is_pt_tf_cross_test,
-    require_pandas,
-    require_scatter,
-    require_tokenizers,
-    require_torch,
-    slow,
-)
+from transformers.testing_utils import is_pt_tf_cross_test, require_pandas, require_tokenizers, require_torch, slow
 from ..test_tokenization_common import (
     SMALL_TRAINING_CORPUS,
@@ -1219,7 +1212,6 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     @require_torch
     @slow
-    @require_scatter
     def test_torch_encode_plus_sent_to_model(self):
         import torch
@@ -1254,10 +1246,15 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                 words, boxes = self.get_words_and_boxes()
                 encoded_sequence = tokenizer.encode_plus(words, boxes=boxes, return_tensors="pt")
                 batch_encoded_sequence = tokenizer.batch_encode_plus(
-                    [words, words], [boxes, boxes], return_tensors="pt"
+                    [words, words], boxes=[boxes, boxes], return_tensors="pt"
                 )
-                # This should not fail
+                # We add dummy image keys (as LayoutLMv2 actually also requires a feature extractor
+                # to prepare the image input)
+                encoded_sequence["image"] = torch.randn(1, 3, 224, 224)
+                batch_encoded_sequence["image"] = torch.randn(2, 3, 224, 224)
+
+                # This should not fail
                 with torch.no_grad():  # saves some time
                     model(**encoded_sequence)
                     model(**batch_encoded_sequence)
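
Note: the test bypasses the feature extractor by attaching random 3x224x224 tensors under the "image" key. Outside of the test, that key would normally come from the LayoutLMv2 processor. A minimal sketch of that flow, assuming the microsoft/layoutlmv2-base-uncased checkpoint and a hypothetical local file named document.png:

    # Sketch only: prepare real LayoutLMv2 inputs with a processor instead of dummy tensors.
    # "document.png" is a placeholder path, not part of the test above.
    import torch
    from PIL import Image
    from transformers import LayoutLMv2Model, LayoutLMv2Processor

    processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
    model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased")

    image = Image.open("document.png").convert("RGB")

    # By default the processor runs OCR to obtain words and boxes, and its feature
    # extractor resizes the page to 224x224 to produce the "image" input.
    encoding = processor(image, return_tensors="pt")

    with torch.no_grad():  # mirrors the test: no gradients needed for a forward pass
        outputs = model(**encoding)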