"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c888663f18673003574cffe9608c5aae2bc9ccff"
Unverified Commit 1ddf3c2b authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Fix vit test (#15671)

parent 943e2aa0
...@@ -113,19 +113,13 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -113,19 +113,13 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
@require_torch @require_torch
def test_small_model_pt(self): def test_small_model_pt(self):
small_model = "lysandre/tiny-vit-random" small_model = "hf-internal-testing/tiny-random-vit"
image_classifier = pipeline("image-classification", model=small_model) image_classifier = pipeline("image-classification", model=small_model)
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg") outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
{"score": 0.0014, "label": "trench coat"},
{"score": 0.0014, "label": "handkerchief, hankie, hanky, hankey"},
{"score": 0.0014, "label": "baboon"},
],
) )
outputs = image_classifier( outputs = image_classifier(
...@@ -138,32 +132,20 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -138,32 +132,20 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
[ [{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"}, [{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
],
[
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
],
], ],
) )
@require_tf @require_tf
def test_small_model_tf(self): def test_small_model_tf(self):
small_model = "lysandre/tiny-vit-random" small_model = "hf-internal-testing/tiny-random-vit"
image_classifier = pipeline("image-classification", model=small_model) image_classifier = pipeline("image-classification", model=small_model)
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg") outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
{"score": 0.0014, "label": "trench coat"},
{"score": 0.0014, "label": "handkerchief, hankie, hanky, hankey"},
{"score": 0.0014, "label": "baboon"},
],
) )
outputs = image_classifier( outputs = image_classifier(
...@@ -176,14 +158,8 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -176,14 +158,8 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
[ [{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"}, [{"label": "LABEL_1", "score": 0.574}, {"label": "LABEL_0", "score": 0.426}],
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
],
[
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
],
], ],
) )
...@@ -191,7 +167,9 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -191,7 +167,9 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
tokenizer = PreTrainedTokenizer() tokenizer = PreTrainedTokenizer()
# Assert that the pipeline can be initialized with a feature extractor that is not in any mapping # Assert that the pipeline can be initialized with a feature extractor that is not in any mapping
image_classifier = pipeline("image-classification", model="lysandre/tiny-vit-random", tokenizer=tokenizer) image_classifier = pipeline(
"image-classification", model="hf-internal-testing/tiny-random-vit", tokenizer=tokenizer
)
self.assertIs(image_classifier.tokenizer, tokenizer) self.assertIs(image_classifier.tokenizer, tokenizer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment