"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6b29bff852a6dafa4e2c854a4dca19836cb7c72d"
Unverified commit 693720e5, authored by Yih-Dar, committed by GitHub

Fix LayoutXLMProcessorTest (#17506)



* fix
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 4d1ce396
@@ -22,7 +22,6 @@ from typing import List
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
 from transformers.models.layoutxlm import LayoutXLMTokenizer, LayoutXLMTokenizerFast
 from transformers.testing_utils import (
-    get_tests_dir,
     require_pytesseract,
     require_sentencepiece,
     require_tokenizers,
@@ -38,9 +37,6 @@ if is_pytesseract_available():
     from transformers import LayoutLMv2FeatureExtractor, LayoutXLMProcessor
 
 
-SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")
-
-
 @require_pytesseract
 @require_sentencepiece
 @require_tokenizers
@@ -60,11 +56,14 @@ class LayoutXLMProcessorTest(unittest.TestCase):
         with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
             fp.write(json.dumps(feature_extractor_map) + "\n")
 
+        # taken from `test_tokenization_layoutxlm.LayoutXLMTokenizationTest.test_save_pretrained`
+        self.tokenizer_pretrained_name = "hf-internal-testing/tiny-random-layoutxlm"
+
     def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
-        return self.tokenizer_class.from_pretrained(SAMPLE_SP, **kwargs)
+        return self.tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)
 
     def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
-        return self.rust_tokenizer_class.from_pretrained(SAMPLE_SP, **kwargs)
+        return self.rust_tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)
 
     def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]:
        return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
...
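For context, a minimal sketch of what the changed helpers now do. The checkpoint name comes from the diff above; the assumption (true in the upstream test file, though not shown in this hunk) is that tokenizer_class and rust_tokenizer_class point at LayoutXLMTokenizer and LayoutXLMTokenizerFast:

from transformers import LayoutXLMTokenizer, LayoutXLMTokenizerFast

# Both the slow and the fast tokenizer are now loaded from the same tiny
# Hub checkpoint instead of a local sentencepiece fixture file, so the two
# classes share one vocabulary and set of tokenizer files in the processor tests.
checkpoint = "hf-internal-testing/tiny-random-layoutxlm"
slow_tokenizer = LayoutXLMTokenizer.from_pretrained(checkpoint)
fast_tokenizer = LayoutXLMTokenizerFast.from_pretrained(checkpoint)

Design-wise, pointing both helpers at one Hub repo keeps the slow and fast tokenizers in sync and removes the test's dependency on a single local fixture file.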