Commit 59fe641b authored by thomwolf

also gathering file names in file_utils

parent d68a8fe4
__init__.py

@@ -24,7 +24,7 @@ from .tokenization_roberta import RobertaTokenizer
 from .tokenization_distilbert import DistilBertTokenizer
 # Configurations
-from .configuration_utils import CONFIG_NAME, PretrainedConfig
+from .configuration_utils import PretrainedConfig
 from .configuration_auto import AutoConfig
 from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
 from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
@@ -36,7 +36,7 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH
 from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
 # Modeling
-from .modeling_utils import (WEIGHTS_NAME, TF_WEIGHTS_NAME, PreTrainedModel, prune_layer, Conv1D)
+from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
 from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
                             AutoModelWithLMHead)
@@ -70,4 +70,6 @@ from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, Wa
                            WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
 # Files and general utilities
-from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path, add_start_docstrings, add_end_docstrings)
+from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
+                         cached_path, add_start_docstrings, add_end_docstrings,
+                         WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
configuration_utils.py

@@ -24,12 +24,10 @@ import logging
 import os
 from io import open
-from .file_utils import cached_path
+from .file_utils import cached_path, CONFIG_NAME
 logger = logging.getLogger(__name__)
-CONFIG_NAME = "config.json"
 class PretrainedConfig(object):
     r""" Base class for all configuration classes.
         Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
file_utils.py

@@ -48,6 +48,10 @@ except (AttributeError, ImportError):
 PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility
+WEIGHTS_NAME = "pytorch_model.bin"
+TF_WEIGHTS_NAME = 'model.ckpt'
+CONFIG_NAME = "config.json"
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 if not six.PY2:
modeling_utils.py

@@ -31,13 +31,10 @@ from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from .configuration_utils import PretrainedConfig
-from .file_utils import cached_path
+from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
 logger = logging.getLogger(__name__)
-WEIGHTS_NAME = "pytorch_model.bin"
-TF_WEIGHTS_NAME = 'model.ckpt'
 try:
     from torch.nn import Identity
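
The net effect of the commit is that the checkpoint and config file-name constants (WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME) are now defined once in file_utils and re-exported from the package root instead of being scattered across modeling_utils and configuration_utils. Below is a minimal sketch of how downstream code could rely on the consolidated constants after this change; the package import name (pytorch_transformers) and the checkpoint_paths helper are assumptions made for illustration, not part of the library.

# Minimal sketch, assuming the library is installed as pytorch_transformers.
# checkpoint_paths is a hypothetical helper, not part of the commit.
import os

from pytorch_transformers import WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME

def checkpoint_paths(save_directory):
    """Return the standard file paths a saved model directory is expected to contain."""
    return {
        "weights": os.path.join(save_directory, WEIGHTS_NAME),           # pytorch_model.bin
        "config": os.path.join(save_directory, CONFIG_NAME),             # config.json
        "tf_checkpoint": os.path.join(save_directory, TF_WEIGHTS_NAME),  # model.ckpt
    }

if __name__ == "__main__":
    print(checkpoint_paths("./my_model"))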