Unverified Commit 0c92e7d9 authored by Santiago Castro's avatar Santiago Castro Committed by GitHub
Browse files

Fix ignore list behavior in doctests (#8213)

parent 84caa233
...@@ -494,8 +494,8 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r""" ...@@ -494,8 +494,8 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
- The model is a model provided by the library (loaded with the `shortcut name` string of a - The model is a model provided by the library (loaded with the `shortcut name` string of a
pretrained model). pretrained model).
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by suppling the save directory. by supplying the save directory.
- The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory. configuration JSON file named `config.json` is found in the directory.
state_dict (`Dict[str, torch.Tensor]`, `optional`): state_dict (`Dict[str, torch.Tensor]`, `optional`):
A state dictionary to use instead of a state dictionary loaded from saved weights file. A state dictionary to use instead of a state dictionary loaded from saved weights file.
......
...@@ -550,8 +550,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): ...@@ -550,8 +550,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
- The model is a model provided by the library (loaded with the `shortcut name` string of a - The model is a model provided by the library (loaded with the `shortcut name` string of a
pretrained model). pretrained model).
- The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded - The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded
by suppling the save directory. by supplying the save directory.
- The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory. configuration JSON file named `config.json` is found in the directory.
from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`): from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a PyTorch state_dict save file (see docstring of Load the model weights from a PyTorch state_dict save file (see docstring of
......
...@@ -784,8 +784,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -784,8 +784,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
- The model is a model provided by the library (loaded with the `shortcut name` string of a - The model is a model provided by the library (loaded with the `shortcut name` string of a
pretrained model). pretrained model).
- The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by suppling the save directory. by supplying the save directory.
- The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory. configuration JSON file named `config.json` is found in the directory.
state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`): state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`):
A state dictionary to use instead of a state dictionary loaded from saved weights file. A state dictionary to use instead of a state dictionary loaded from saved weights file.
......
...@@ -36,8 +36,8 @@ class TestCodeExamples(unittest.TestCase): ...@@ -36,8 +36,8 @@ class TestCodeExamples(unittest.TestCase):
self, self,
directory: Path, directory: Path,
identifier: Union[str, None] = None, identifier: Union[str, None] = None,
ignore_files: Union[List[str], None] = [], ignore_files: Union[List[str], None] = None,
n_identifier: Union[str, None] = None, n_identifier: Union[str, List[str], None] = None,
only_modules: bool = True, only_modules: bool = True,
): ):
""" """
...@@ -45,7 +45,7 @@ class TestCodeExamples(unittest.TestCase): ...@@ -45,7 +45,7 @@ class TestCodeExamples(unittest.TestCase):
the doctests in those files the doctests in those files
Args: Args:
directory (:obj:`str`): Directory containing the files directory (:obj:`Path`): Directory containing the files
identifier (:obj:`str`): Will parse files containing this identifier (:obj:`str`): Will parse files containing this
ignore_files (:obj:`List[str]`): List of files to skip ignore_files (:obj:`List[str]`): List of files to skip
n_identifier (:obj:`str` or :obj:`List[str]`): Will not parse files containing this/these identifiers. n_identifier (:obj:`str` or :obj:`List[str]`): Will not parse files containing this/these identifiers.
...@@ -63,6 +63,7 @@ class TestCodeExamples(unittest.TestCase): ...@@ -63,6 +63,7 @@ class TestCodeExamples(unittest.TestCase):
else: else:
files = [file for file in files if n_identifier not in file] files = [file for file in files if n_identifier not in file]
ignore_files = ignore_files or []
ignore_files.append("__init__.py") ignore_files.append("__init__.py")
files = [file for file in files if file not in ignore_files] files = [file for file in files if file not in ignore_files]
...@@ -71,8 +72,8 @@ class TestCodeExamples(unittest.TestCase): ...@@ -71,8 +72,8 @@ class TestCodeExamples(unittest.TestCase):
print("Testing", file) print("Testing", file)
if only_modules: if only_modules:
try:
module_identifier = file.split(".")[0] module_identifier = file.split(".")[0]
try:
module_identifier = getattr(transformers, module_identifier) module_identifier = getattr(transformers, module_identifier)
suite = doctest.DocTestSuite(module_identifier) suite = doctest.DocTestSuite(module_identifier)
result = unittest.TextTestRunner().run(suite) result = unittest.TextTestRunner().run(suite)
...@@ -84,7 +85,7 @@ class TestCodeExamples(unittest.TestCase): ...@@ -84,7 +85,7 @@ class TestCodeExamples(unittest.TestCase):
self.assertIs(result.failed, 0) self.assertIs(result.failed, 0)
def test_modeling_examples(self): def test_modeling_examples(self):
transformers_directory = "src/transformers" transformers_directory = Path("src/transformers")
files = "modeling" files = "modeling"
ignore_files = [ ignore_files = [
"modeling_ctrl.py", "modeling_ctrl.py",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment