Unverified Commit 7f20bf0d authored by Matt's avatar Matt Committed by GitHub
Browse files

Fixing requirements for TF LM models and use correct model mappings (#14372)

* Fixing requirements for TF LM models and use correct model mappings

* make style
parent 4c35c8d8
datasets >= 1.8.0
sentencepiece != 0.1.92
\ No newline at end of file
...@@ -43,8 +43,8 @@ import transformers ...@@ -43,8 +43,8 @@ import transformers
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
CONFIG_NAME, CONFIG_NAME,
MODEL_FOR_CAUSAL_LM_MAPPING,
TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME,
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
AutoConfig, AutoConfig,
AutoTokenizer, AutoTokenizer,
HfArgumentParser, HfArgumentParser,
...@@ -57,8 +57,8 @@ from transformers.utils.versions import require_version ...@@ -57,8 +57,8 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
# endregion # endregion
......
...@@ -45,8 +45,8 @@ import transformers ...@@ -45,8 +45,8 @@ import transformers
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
CONFIG_NAME, CONFIG_NAME,
MODEL_FOR_MASKED_LM_MAPPING,
TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME,
TF_MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig, AutoConfig,
AutoTokenizer, AutoTokenizer,
HfArgumentParser, HfArgumentParser,
...@@ -59,8 +59,8 @@ from transformers.utils.versions import require_version ...@@ -59,8 +59,8 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment