Unverified Commit 042a6aa7 authored by Julien Chaumond's avatar Julien Chaumond Committed by GitHub
Browse files

Tokenizers: ability to load from model subfolder (#8586)



* <small>tiny typo</small>

* Tokenizers: ability to load from model subfolder

* use subfolder for local files as well

* Uniformize model shortcut name => model id

* from s3 => from huggingface.co
Co-authored-by: default avatarQuentin Lhoest <lhoest.q@gmail.com>
parent 48395d6b
......@@ -403,10 +403,9 @@ TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
pretrained_model_name_or_path:
Can be either:
- A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
``bert-base-uncased``.
- A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
``dbmdz/bert-base-german-cased``.
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `PyTorch state_dict save file` (e.g., ``./pt_model/pytorch_model.bin``). In
......@@ -420,8 +419,8 @@ TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the `shortcut name` string of a
pretrained model).
- The model is a model provided by the library (loaded with the `model id` string of a pretrained
model).
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by supplying the save directory.
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
......@@ -507,7 +506,7 @@ class TFAutoModel(object):
Examples::
>>> from transformers import AutoConfig, TFAutoModel
>>> # Download configuration from S3 and cache.
>>> # Download configuration from huggingface.co and cache.
>>> config = TFAutoConfig.from_pretrained('bert-base-uncased')
>>> model = TFAutoModel.from_config(config)
"""
......@@ -533,7 +532,7 @@ class TFAutoModel(object):
>>> from transformers import AutoConfig, AutoModel
>>> # Download model and configuration from S3 and cache.
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModel.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
......@@ -601,7 +600,7 @@ class TFAutoModelForPreTraining(object):
Examples::
>>> from transformers import AutoConfig, TFAutoModelForPreTraining
>>> # Download configuration from S3 and cache.
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
>>> model = TFAutoModelForPreTraining.from_config(config)
"""
......@@ -627,7 +626,7 @@ class TFAutoModelForPreTraining(object):
>>> from transformers import AutoConfig, TFAutoModelForPreTraining
>>> # Download model and configuration from S3 and cache.
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForPreTraining.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
......@@ -701,7 +700,7 @@ class TFAutoModelWithLMHead(object):
Examples::
>>> from transformers import AutoConfig, TFAutoModelWithLMHead
>>> # Download configuration from S3 and cache.
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
>>> model = TFAutoModelWithLMHead.from_config(config)
"""
......@@ -733,7 +732,7 @@ class TFAutoModelWithLMHead(object):
>>> from transformers import AutoConfig, TFAutoModelWithLMHead
>>> # Download model and configuration from S3 and cache.
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelWithLMHead.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
......@@ -808,7 +807,7 @@ class TFAutoModelForCausalLM:
Examples::
>>> from transformers import AutoConfig, TFAutoModelForCausalLM
>>> # Download configuration from S3 and cache.
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('gpt2')
>>> model = TFAutoModelForCausalLM.from_config(config)
"""
......@@ -834,7 +833,7 @@ class TFAutoModelForCausalLM:
>>> from transformers import AutoConfig, TFAutoModelForCausalLM
>>> # Download model and configuration from S3 and cache.
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForCausalLM.from_pretrained('gpt2')
>>> # Update configuration during loading
......@@ -902,7 +901,7 @@ class TFAutoModelForMaskedLM:
Examples::
>>> from transformers import AutoConfig, TFAutoModelForMaskedLM
>>> # Download configuration from S3 and cache.
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
>>> model = TFAutoModelForMaskedLM.from_config(config)
"""
......@@ -928,7 +927,7 @@ class TFAutoModelForMaskedLM:
>>> from transformers import AutoConfig, TFAutoModelForMaskedLM
>>> # Download model and configuration from S3 and cache.
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMaskedLM.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
......@@ -996,7 +995,7 @@ class TFAutoModelForSeq2SeqLM:
Examples::
>>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM
>>> # Download configuration from S3 and cache.
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('t5')
>>> model = TFAutoModelForSeq2SeqLM.from_config(config)
"""
......@@ -1024,7 +1023,7 @@ class TFAutoModelForSeq2SeqLM:
>>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM
>>> # Download model and configuration from S3 and cache.
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained('t5-base')
>>> # Update configuration during loading
......@@ -1094,7 +1093,7 @@ class TFAutoModelForSequenceClassification(object):
Examples::
>>> from transformers import AutoConfig, TFAutoModelForSequenceClassification
>>> # Download configuration from S3 and cache.
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
>>> model = TFAutoModelForSequenceClassification.from_config(config)
"""
......@@ -1122,7 +1121,7 @@ class TFAutoModelForSequenceClassification(object):
>>> from transformers import AutoConfig, TFAutoModelForSequenceClassification
>>> # Download model and configuration from S3 and cache.
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
......@@ -1191,7 +1190,7 @@ class TFAutoModelForQuestionAnswering(object):
Examples::
>>> from transformers import AutoConfig, TFAutoModelForQuestionAnswering
>>> # Download configuration from S3 and cache.
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
>>> model = TFAutoModelForQuestionAnswering.from_config(config)
"""
......@@ -1219,7 +1218,7 @@ class TFAutoModelForQuestionAnswering(object):
>>> from transformers import AutoConfig, TFAutoModelForQuestionAnswering
>>> # Download model and configuration from S3 and cache.
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
......@@ -1288,7 +1287,7 @@ class TFAutoModelForTokenClassification:
Examples::
>>> from transformers import AutoConfig, TFAutoModelForTokenClassification
>>> # Download configuration from S3 and cache.
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
>>> model = TFAutoModelForTokenClassification.from_config(config)
"""
......@@ -1316,7 +1315,7 @@ class TFAutoModelForTokenClassification:
>>> from transformers import AutoConfig, TFAutoModelForTokenClassification
>>> # Download model and configuration from S3 and cache.
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForTokenClassification.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
......@@ -1386,7 +1385,7 @@ class TFAutoModelForMultipleChoice:
Examples::
>>> from transformers import AutoConfig, TFAutoModelForMultipleChoice
>>> # Download configuration from S3 and cache.
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
>>> model = TFAutoModelForMultipleChoice.from_config(config)
"""
......@@ -1414,7 +1413,7 @@ class TFAutoModelForMultipleChoice:
>>> from transformers import AutoConfig, TFAutoModelForMultipleChoice
>>> # Download model and configuration from S3 and cache.
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMultipleChoice.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
......@@ -1484,7 +1483,7 @@ class TFAutoModelForNextSentencePrediction:
Examples::
>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
>>> # Download configuration from S3 and cache.
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
>>> model = TFAutoModelForNextSentencePrediction.from_config(config)
"""
......@@ -1512,7 +1511,7 @@ class TFAutoModelForNextSentencePrediction:
>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
>>> # Download model and configuration from S3 and cache.
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
......
......@@ -250,10 +250,9 @@ class AutoTokenizer:
pretrained_model_name_or_path (:obj:`str`):
Can be either:
- A string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.,
``bert-base-uncased``.
- A string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3,
e.g., ``dbmdz/bert-base-german-cased``.
- A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved
using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.,
``./my_model_directory/``.
......@@ -280,6 +279,9 @@ class AutoTokenizer:
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
subfolder (:obj:`str`, `optional`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
facebook/rag-token-base), specify it here.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to try to load the fast version of the tokenizer.
kwargs (additional keyword arguments, `optional`):
......@@ -291,10 +293,10 @@ class AutoTokenizer:
>>> from transformers import AutoTokenizer
>>> # Download vocabulary from S3 and cache.
>>> # Download vocabulary from huggingface.co and cache.
>>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
>>> # Download vocabulary from S3 (user-uploaded) and cache.
>>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
>>> tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased')
>>> # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
......
......@@ -214,10 +214,9 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_pretrained_model_name_or_path (:obj: `str`, `optional`):
Information necessary to initiate the encoder. Can be either:
- A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
``bert-base-uncased``.
- A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
``dbmdz/bert-base-german-cased``.
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `tensorflow index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In
......@@ -228,10 +227,9 @@ class EncoderDecoderModel(PreTrainedModel):
decoder_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
Information necessary to initiate the decoder. Can be either:
- A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
``bert-base-uncased``.
- A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
``dbmdz/bert-base-german-cased``.
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `tensorflow index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In
......
......@@ -24,7 +24,7 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
# to pretrained vocabulary URL for all the model ids.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
......@@ -33,13 +33,13 @@ PRETRAINED_VOCAB_FILES_MAP = {
}
####################################################
# Mapping from model shortcut names to max length of inputs
# Mapping from model ids to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"unc-nlp/lxmert-base-uncased": 512,
}
####################################################
# Mapping from model shortcut names to a dictionary of additional
# Mapping from model ids to a dictionary of additional
# keyword arguments for Tokenizer `__init__`.
# To be used for checkpoint specific configurations.
####################################################
......
......@@ -25,7 +25,7 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.jso
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
# to pretrained vocabulary URL for all the model ids.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
......@@ -37,13 +37,13 @@ PRETRAINED_VOCAB_FILES_MAP = {
}
####################################################
# Mapping from model shortcut names to max length of inputs
# Mapping from model ids to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"unc-nlp/lxmert-base-uncased": 512,
}
####################################################
# Mapping from model shortcut names to a dictionary of additional
# Mapping from model ids to a dictionary of additional
# keyword arguments for Tokenizer `__init__`.
# To be used for checkpoint specific configurations.
####################################################
......
......@@ -238,10 +238,9 @@ class RagPreTrainedModel(PreTrainedModel):
question_encoder_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
Information necessary to initiate the question encoder. Can be either:
- A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
``bert-base-uncased``.
- A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
``dbmdz/bert-base-german-cased``.
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `tensorflow index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In
......@@ -252,10 +251,9 @@ class RagPreTrainedModel(PreTrainedModel):
generator_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
Information necessary to initiate the generator. Can be either:
- A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
``bert-base-uncased``.
- A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
``dbmdz/bert-base-german-cased``.
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `tensorflow index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In
......
......@@ -49,10 +49,12 @@ class RagTokenizer:
if config is None:
config = RagConfig.from_pretrained(pretrained_model_name_or_path)
question_encoder_path = os.path.join(pretrained_model_name_or_path, "question_encoder_tokenizer")
generator_path = os.path.join(pretrained_model_name_or_path, "generator_tokenizer")
question_encoder = AutoTokenizer.from_pretrained(question_encoder_path, config=config.question_encoder)
generator = AutoTokenizer.from_pretrained(generator_path, config=config.generator)
question_encoder = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, config=config.question_encoder, subfolder="question_encoder_tokenizer"
)
generator = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, config=config.generator, subfolder="generator_tokenizer"
)
return cls(question_encoder=question_encoder, generator=generator)
def __call__(self, *args, **kwargs):
......
......@@ -38,7 +38,7 @@ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
# to pretrained vocabulary URL for all the model ids.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
......@@ -47,7 +47,7 @@ PRETRAINED_VOCAB_FILES_MAP = {
}
####################################################
# Mapping from model shortcut names to max length of inputs
# Mapping from model ids to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"google/reformer-crime-and-punishment": 524288,
......
......@@ -43,7 +43,7 @@ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
# to pretrained vocabulary URL for all the model ids.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
......@@ -55,7 +55,7 @@ PRETRAINED_VOCAB_FILES_MAP = {
}
####################################################
# Mapping from model shortcut names to max length of inputs
# Mapping from model ids to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"google/reformer-crime-and-punishment": 524288,
......
......@@ -50,7 +50,7 @@ _CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
####################################################
# This dict contains shortcut names and associated url
# This dict contains ids and associated url
# for the pretrained weights provided with the models
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
......
......@@ -39,7 +39,7 @@ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
# to pretrained vocabulary URL for all the model ids.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
......@@ -52,7 +52,7 @@ PRETRAINED_VOCAB_FILES_MAP = {
}
####################################################
# Mapping from model shortcut names to max length of inputs
# Mapping from model ids to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"t5-small": 512,
......
......@@ -42,7 +42,7 @@ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
# to pretrained vocabulary URL for all the model ids.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
......@@ -62,7 +62,7 @@ PRETRAINED_VOCAB_FILES_MAP = {
}
####################################################
# Mapping from model shortcut names to max length of inputs
# Mapping from model ids to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"t5-small": 512,
......
......@@ -1615,10 +1615,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
pretrained_model_name_or_path (:obj:`str`):
Can be either:
- A string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.,
``bert-base-uncased``.
- A string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.,
``dbmdz/bert-base-german-cased``.
- A string, the `model id` of a predefined tokenizer hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a
user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing vocabulary files required by the tokenizer, for instance saved
using the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`
method, e.g., ``./my_model_directory/``.
......@@ -1641,6 +1640,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
subfolder (:obj:`str`, `optional`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
facebook/rag-token-base), specify it here.
inputs (additional positional arguments, `optional`):
Will be passed along to the Tokenizer ``__init__`` method.
kwargs (additional keyword arguments, `optional`):
......@@ -1651,10 +1653,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
Examples::
# We can't instantiate directly the base class `PreTrainedTokenizerBase` so let's show our examples on a derived class: BertTokenizer
# Download vocabulary from S3 and cache.
# Download vocabulary from huggingface.co and cache.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Download vocabulary from S3 (user-uploaded) and cache.
# Download vocabulary from huggingface.co (user-uploaded) and cache.
tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-german-cased')
# If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
......@@ -1676,6 +1678,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
......@@ -1722,13 +1725,20 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
# Look for the tokenizer files
for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
if os.path.isdir(pretrained_model_name_or_path):
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
if subfolder is not None:
full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
else:
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
if not os.path.exists(full_file_name):
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
full_file_name = None
else:
full_file_name = hf_bucket_url(
pretrained_model_name_or_path, filename=file_name, revision=revision, mirror=None
pretrained_model_name_or_path,
filename=file_name,
subfolder=subfolder,
revision=revision,
mirror=None,
)
vocab_files[file_id] = full_file_name
......
......@@ -75,7 +75,7 @@ class ModelArguments:
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}
)
use_fast_tokenizer: bool = field(
default=True,
......@@ -98,7 +98,7 @@ class ModelArguments:
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}
)
use_fast_tokenizer: bool = field(
default=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment