Unverified Commit 09683883 authored by Yih-Dar, committed by GitHub

Throw an error if `getattribute_from_module` can't find anything (#19535)



* return None to avoid recursive call

* Give error

* Give error

* Add test

* More tests

* Quality
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 383ad81e
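The first commit in the squash ("return None to avoid recursive call") points at the underlying problem: before this patch, a mapping entry naming a class that does not exist anywhere in `transformers` made `getattribute_from_module` call itself with the same arguments until Python hit its recursion limit. A simplified sketch of that pre-patch failure mode (not the actual source; the real function also handles `None` and tuple attributes):

import importlib

def old_getattribute_from_module(module, attr):
    # Simplified pre-patch logic.
    if hasattr(module, attr):
        return getattr(module, attr)
    transformers_module = importlib.import_module("transformers")
    # If `module` is already the top-level `transformers` module and `attr`
    # is missing, this calls itself with identical arguments and eventually
    # raises RecursionError.
    return old_getattribute_from_module(transformers_module, attr)

The diff below replaces that unconditional fallback with an explicit ValueError.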
@@ -555,7 +555,14 @@ def getattribute_from_module(module, attr):
     # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
     # object at the top level.
     transformers_module = importlib.import_module("transformers")
-    return getattribute_from_module(transformers_module, attr)
+
+    if module != transformers_module:
+        try:
+            return getattribute_from_module(transformers_module, attr)
+        except ValueError:
+            raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
+    else:
+        raise ValueError(f"Could not find {attr} in {transformers_module}!")
 
 
 class _LazyAutoMapping(OrderedDict):
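The hunk above appears to come from src/transformers/models/auto/auto_factory.py (the new test below imports `_LazyAutoMapping` from that module; the diff itself does not show file names). A minimal sketch of how the patched lookup is expected to behave; "GhostModel" is a made-up class name used only to trigger the error path:

import importlib

from transformers.models.auto.auto_factory import getattribute_from_module

transformers_module = importlib.import_module("transformers")

# An attribute that exists still resolves to the real class.
bert_config_cls = getattribute_from_module(transformers_module, "BertConfig")
print(bert_config_cls.__name__)  # BertConfig

# A missing attribute now fails loudly instead of recursing.
try:
    getattribute_from_module(transformers_module, "GhostModel")
except ValueError as err:
    print(err)  # Could not find GhostModel in <module 'transformers' ...>!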
@@ -17,9 +17,12 @@ import copy
 import sys
 import tempfile
 import unittest
+from collections import OrderedDict
 from pathlib import Path
 
-from transformers import BertConfig, is_torch_available
+import pytest
+
+from transformers import BertConfig, GPT2Model, is_torch_available
 from transformers.models.auto.configuration_auto import CONFIG_MAPPING
 from transformers.testing_utils import (
     DUMMY_UNKNOWN_IDENTIFIER,
@@ -372,3 +375,22 @@ class AutoModelTest(unittest.TestCase):
         self.assertEqual(counter.get_request_count, 0)
         self.assertEqual(counter.head_request_count, 1)
         self.assertEqual(counter.other_request_count, 0)
+
+    def test_attr_not_existing(self):
+        from transformers.models.auto.auto_factory import _LazyAutoMapping
+
+        _CONFIG_MAPPING_NAMES = OrderedDict([("bert", "BertConfig")])
+        _MODEL_MAPPING_NAMES = OrderedDict([("bert", "GhostModel")])
+        _MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
+
+        with pytest.raises(ValueError, match=r"Could not find GhostModel neither in .* nor in .*!"):
+            _MODEL_MAPPING[BertConfig]
+
+        _MODEL_MAPPING_NAMES = OrderedDict([("bert", "BertModel")])
+        _MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
+        self.assertEqual(_MODEL_MAPPING[BertConfig], BertModel)
+
+        _MODEL_MAPPING_NAMES = OrderedDict([("bert", "GPT2Model")])
+        _MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
+        self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model)
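The new test drives `_LazyAutoMapping` directly: a hypothetical "GhostModel" entry must raise the new ValueError, while real entries (`BertModel`, and the deliberately mismatched `GPT2Model`, which exercises the top-level fallback) must still resolve. Note that `BertModel` is not part of the visible import change, so it is presumably already imported elsewhere in the test file, e.g. under the existing `is_torch_available()` guard. To run just this test locally, something along the lines of `python -m pytest -k test_attr_not_existing` against the auto-model test file should work; the exact file path is not shown in this diff.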