Unverified Commit 5ba2dbd9 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `AutoModelTest.test_model_from_pretrained` (#20730)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent a3345c1f
......@@ -22,7 +22,7 @@ from pathlib import Path
import pytest
from transformers import BertConfig, GPT2Model, is_torch_available
from transformers import BertConfig, GPT2Model, is_safetensors_available, is_torch_available
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.testing_utils import (
DUMMY_UNKNOWN_IDENTIFIER,
......@@ -102,7 +102,10 @@ class AutoModelTest(unittest.TestCase):
self.assertIsInstance(model, BertModel)
self.assertEqual(len(loading_info["missing_keys"]), 0)
self.assertEqual(len(loading_info["unexpected_keys"]), 8)
# When using PyTorch checkpoint, the expected value is `8`. With `safetensors` checkpoint (if it is
# installed), the expected value becomes `7`.
EXPECTED_NUM_OF_UNEXPECTED_KEYS = 7 if is_safetensors_available() else 8
self.assertEqual(len(loading_info["unexpected_keys"]), EXPECTED_NUM_OF_UNEXPECTED_KEYS)
self.assertEqual(len(loading_info["mismatched_keys"]), 0)
self.assertEqual(len(loading_info["error_msgs"]), 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment