Unverified Commit e4515faf authored by Thomas Wolf, committed by GitHub

Merge pull request #1057 from huggingface/fixes

Add a few typo corrections, bug fixes and small improvements
parents 41789c6c 43489756
@@ -127,4 +127,6 @@ proc_data
 # examples
 runs
-examples/runs
\ No newline at end of file
+examples/runs
+data
\ No newline at end of file
@@ -72,16 +72,16 @@ Here is the full list of the currently provided pretrained models together with
 | | ``xlnet-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
 | | | | XLNet Large English model |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 1024-hidden, 8-heads |
+| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 2048-hidden, 16-heads |
 | | | | XLM English model |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``xlm-mlm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads |
+| | ``xlm-mlm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads |
 | | | | XLM English-German Multi-language model |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``xlm-mlm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
+| | ``xlm-mlm-enfr-1024`` | | 6-layer, 1024-hidden, 8-heads |
 | | | | XLM English-French Multi-language model |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``xlm-mlm-enro-1024`` | | 12-layer, 1024-hidden, 8-heads |
+| | ``xlm-mlm-enro-1024`` | | 6-layer, 1024-hidden, 8-heads |
 | | | | XLM English-Romanian Multi-language model |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
 | | ``xlm-mlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads |
@@ -93,7 +93,7 @@ Here is the full list of the currently provided pretrained models together with
 | | ``xlm-clm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
 | | | | XLM English model trained with CLM (Causal Language Modeling) |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``xlm-clm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads |
+| | ``xlm-clm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads |
 | | | | XLM English-German Multi-language model trained with CLM (Causal Language Modeling) |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
 | RoBERTa | ``roberta-base`` | | 12-layer, 768-hidden, 12-heads, 125M parameters |
......
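The corrected layer/hidden/head counts above can be sanity-checked by loading the corresponding config. A quick sketch, assuming the usual ``XLMConfig`` attribute names (``n_layers``, ``emb_dim``, ``n_heads``), which are not shown in this diff:

```python
# Sanity-check the corrected sizes for a shortcut name (sketch; assumes the
# standard XLMConfig attribute names).
from pytorch_transformers import XLMConfig

config = XLMConfig.from_pretrained('xlm-mlm-en-2048')
print(config.n_layers, config.emb_dim, config.n_heads)  # expected: 12 2048 16
```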
@@ -17,8 +17,9 @@ from hashlib import sha256
 from io import open
 import boto3
-import requests
+from botocore.config import Config
 from botocore.exceptions import ClientError
+import requests
 from tqdm import tqdm
 try:
@@ -93,12 +94,15 @@ def filename_to_url(filename, cache_dir=None):
     return url, etag
-def cached_path(url_or_filename, cache_dir=None):
+def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
     """
     Given something that might be a URL (or might be a local path),
     determine which. If it's a URL, download the file and cache it, and
     return the path to the cached file. If it's already a local path,
     make sure the file exists and then return the path.
+    Args:
+        cache_dir: specify a cache directory to save the file to (overwrites the default cache dir).
+        force_download: if True, re-download the file even if it's already cached in the cache dir.
     """
     if cache_dir is None:
         cache_dir = PYTORCH_TRANSFORMERS_CACHE
@@ -111,7 +115,7 @@ def cached_path(url_or_filename, cache_dir=None):
     if parsed.scheme in ('http', 'https', 's3'):
         # URL, so get it from the cache (downloading if necessary)
-        return get_from_cache(url_or_filename, cache_dir)
+        return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
     elif os.path.exists(url_or_filename):
         # File, and it exists.
         return url_or_filename
@@ -156,24 +160,24 @@ def s3_request(func):
 @s3_request
-def s3_etag(url):
+def s3_etag(url, proxies=None):
     """Check ETag on S3 object."""
-    s3_resource = boto3.resource("s3")
+    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
     bucket_name, s3_path = split_s3_path(url)
     s3_object = s3_resource.Object(bucket_name, s3_path)
     return s3_object.e_tag
 @s3_request
-def s3_get(url, temp_file):
+def s3_get(url, temp_file, proxies=None):
     """Pull a file directly from S3."""
-    s3_resource = boto3.resource("s3")
+    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
     bucket_name, s3_path = split_s3_path(url)
     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
-def http_get(url, temp_file):
-    req = requests.get(url, stream=True)
+def http_get(url, temp_file, proxies=None):
+    req = requests.get(url, stream=True, proxies=proxies)
     content_length = req.headers.get('Content-Length')
     total = int(content_length) if content_length is not None else None
     progress = tqdm(unit="B", total=total)
@@ -184,7 +188,7 @@ def http_get(url, temp_file):
     progress.close()
-def get_from_cache(url, cache_dir=None):
+def get_from_cache(url, cache_dir=None, force_download=False, proxies=None):
     """
     Given a URL, look for the corresponding dataset in the local cache.
     If it's not there, download it. Then return the path to the cached file.
@@ -201,10 +205,10 @@ def get_from_cache(url, cache_dir=None):
     # Get eTag to add to filename, if it exists.
     if url.startswith("s3://"):
-        etag = s3_etag(url)
+        etag = s3_etag(url, proxies=proxies)
     else:
         try:
-            response = requests.head(url, allow_redirects=True)
+            response = requests.head(url, allow_redirects=True, proxies=proxies)
             if response.status_code != 200:
                 etag = None
             else:
@@ -227,17 +231,17 @@ def get_from_cache(url, cache_dir=None):
     if matching_files:
         cache_path = os.path.join(cache_dir, matching_files[-1])
-    if not os.path.exists(cache_path):
+    if not os.path.exists(cache_path) or force_download:
         # Download to temporary file, then copy to cache dir once finished.
         # Otherwise you get corrupt cache entries if the download gets interrupted.
         with tempfile.NamedTemporaryFile() as temp_file:
-            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
+            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
             # GET file object
             if url.startswith("s3://"):
-                s3_get(url, temp_file)
+                s3_get(url, temp_file, proxies=proxies)
             else:
-                http_get(url, temp_file)
+                http_get(url, temp_file, proxies=proxies)
             # we are copying the file before closing it, so flush to avoid truncation
             temp_file.flush()
......
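Taken together, the changes above thread two new knobs, ``force_download`` and ``proxies``, through the whole download path (``cached_path`` → ``get_from_cache`` → ``s3_get``/``http_get``). A minimal usage sketch, assuming a placeholder URL (real model URLs come from the archive maps):

```python
# Sketch of the extended cached_path API; the URL below is a placeholder.
from pytorch_transformers.file_utils import cached_path

# Proxies are keyed by protocol or endpoint, as in the docstrings added here.
proxies = {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}

local_path = cached_path(
    "https://example.com/bert-base-uncased-config.json",  # placeholder URL
    force_download=True,  # re-download even if a cached copy exists
    proxies=proxies,      # used for requests (HTTP) and botocore (S3) calls
)
print(local_path)
```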
@@ -599,7 +599,10 @@ BERT_INPUTS_DOCSTRING = r"""
     ``tokens: [CLS] the dog is hairy . [SEP]``
     ``token_type_ids: 0 0 0 0 0 0 0``
+    Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
+    the right rather than the left.
     Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
     See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
     :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
......
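The same right-padding advice recurs below for GPT-2, GPT, RoBERTa and XLM, so here is one illustrative sketch of what it means in practice (standard pretrained shortcut names; ``pad_token`` is the tokenizer's padding symbol):

```python
# Right-padding for absolute-position models such as BERT: real tokens keep
# their absolute positions 0..n-1 and the padding is masked out.
import torch
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
ids = tokenizer.encode("the dog is hairy")
pad_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

max_len = 12
pad_len = max_len - len(ids)
input_ids = torch.tensor([ids + [pad_id] * pad_len])             # pad on the RIGHT
attention_mask = torch.tensor([[1] * len(ids) + [0] * pad_len])  # mask the padding
```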
@@ -390,6 +390,8 @@ GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in
 GPT2_INPUTS_DOCSTRING = r""" Inputs:
     **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
         Indices of input sequence tokens in the vocabulary.
+        GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
+        the right rather than the left.
         Indices can be obtained using :class:`pytorch_transformers.GPT2Tokenizer`.
         See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
         :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
......
@@ -404,6 +404,8 @@ OPENAI_GPT_START_DOCSTRING = r""" OpenAI GPT model was proposed in
 OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
     **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
         Indices of input sequence tokens in the vocabulary.
+        GPT is a model with absolute position embeddings so it's usually advised to pad the inputs on
+        the right rather than the left.
         Indices can be obtained using :class:`pytorch_transformers.OpenAIGPTTokenizer`.
         See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
         :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
......
@@ -110,6 +110,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
     Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with
     the ``add_special_tokens`` parameter set to ``True``.
+    RoBERTa is a model with absolute position embeddings so it's usually advised to pad the inputs on
+    the right rather than the left.
     See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
     :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
     **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
......
@@ -936,6 +936,8 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
 Inputs:
     **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
         Indices of input sequence tokens in the vocabulary.
+        Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
+        the right or on the left.
         Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
         See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
         :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
......
@@ -125,6 +125,13 @@ class PretrainedConfig(object):
         - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
         - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
+    force_download: (`optional`) boolean, default False:
+        Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
+    proxies: (`optional`) dict, default None:
+        A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
+        The proxies are used on each request.
     return_unused_kwargs: (`optional`) bool:
         - If False, then this function returns just the final configuration object.
@@ -146,6 +153,8 @@ class PretrainedConfig(object):
         """
         cache_dir = kwargs.pop('cache_dir', None)
+        force_download = kwargs.pop('force_download', False)
+        proxies = kwargs.pop('proxies', None)
         return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
         if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
@@ -156,7 +165,7 @@ class PretrainedConfig(object):
             config_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
         except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                 logger.error(
@@ -400,6 +409,13 @@ class PreTrainedModel(nn.Module):
             Path to a directory in which a downloaded pre-trained model
             configuration should be cached if the standard cache should not be used.
+        force_download: (`optional`) boolean, default False:
+            Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
+        proxies: (`optional`) dict, default None:
+            A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
+            The proxies are used on each request.
         output_loading_info: (`optional`) boolean:
             Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
@@ -424,6 +440,8 @@ class PreTrainedModel(nn.Module):
         state_dict = kwargs.pop('state_dict', None)
         cache_dir = kwargs.pop('cache_dir', None)
         from_tf = kwargs.pop('from_tf', False)
+        force_download = kwargs.pop('force_download', False)
+        proxies = kwargs.pop('proxies', None)
         output_loading_info = kwargs.pop('output_loading_info', False)
         # Load config
# Load config # Load config
...@@ -431,6 +449,7 @@ class PreTrainedModel(nn.Module): ...@@ -431,6 +449,7 @@ class PreTrainedModel(nn.Module):
config, model_kwargs = cls.config_class.from_pretrained( config, model_kwargs = cls.config_class.from_pretrained(
pretrained_model_name_or_path, *model_args, pretrained_model_name_or_path, *model_args,
cache_dir=cache_dir, return_unused_kwargs=True, cache_dir=cache_dir, return_unused_kwargs=True,
force_download=force_download,
**kwargs **kwargs
) )
else: else:
@@ -453,7 +472,7 @@ class PreTrainedModel(nn.Module):
             archive_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
-            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
+            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
         except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                 logger.error(
......
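On the user-facing side, the two kwargs now pass straight through ``from_pretrained`` for both configs and models, exactly as the docstrings above describe. A small usage sketch (the proxy address is a placeholder):

```python
# force_download and proxies are popped from kwargs and forwarded to cached_path.
from pytorch_transformers import BertModel

model = BertModel.from_pretrained(
    'bert-base-uncased',
    force_download=True,               # ignore cached weights/config, fetch again
    proxies={'http': 'foo.bar:3128'},  # placeholder proxy address
)
```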
@@ -424,6 +424,10 @@ XLM_INPUTS_DOCSTRING = r"""
 Inputs:
     **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
         Indices of input sequence tokens in the vocabulary.
+        XLM is a model with absolute position embeddings so it's usually advised to pad the inputs on
+        the right rather than the left.
         Indices can be obtained using :class:`pytorch_transformers.XLMTokenizer`.
         See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
         :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
@@ -436,8 +440,10 @@ XLM_INPUTS_DOCSTRING = r"""
         Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
     **langs**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
         A parallel sequence of tokens to be used to indicate the language of each token in the input.
-        Indices are selected in the pre-trained language vocabulary,
-        i.e. in the range ``[0, config.n_langs - 1[``.
+        Indices are language ids which can be obtained from the language names by using two conversion mappings
+        provided in the configuration of the model (only provided for multilingual models).
+        More precisely, the `language name -> language id` mapping is in `model.config.lang2id` (dict str -> int) and
+        the `language id -> language name` mapping is in `model.config.id2lang` (dict int -> str).
     **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
         Mask to avoid performing attention on padding token indices.
         Mask values selected in ``[0, 1]``:
......
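A minimal sketch of building the ``langs`` tensor from the new config mappings, assuming one of the multilingual XLM checkpoints listed earlier and the documented ``forward(..., langs=...)`` argument:

```python
# Build a langs tensor marking every input token as English.
import torch
from pytorch_transformers import XLMModel, XLMTokenizer

model = XLMModel.from_pretrained('xlm-mlm-xnli15-1024')
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-xnli15-1024')

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
en_id = model.config.lang2id['en']         # language name -> language id
langs = torch.full_like(input_ids, en_id)  # one language id per token
outputs = model(input_ids, langs=langs)
```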
@@ -655,6 +655,8 @@ XLNET_INPUTS_DOCSTRING = r"""
 Inputs:
     **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
         Indices of input sequence tokens in the vocabulary.
+        XLNet is a model with relative position embeddings so you can either pad the inputs on
+        the right or on the left.
         Indices can be obtained using :class:`pytorch_transformers.XLNetTokenizer`.
         See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
         :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
......
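For contrast with the right-padding sketch earlier: relative-position models such as XLNet and Transformer-XL also accept left-padded inputs, and XLNet in particular is often padded on the left in practice. A sketch with the standard shortcut name:

```python
# Left-padding, which relative position embeddings tolerate.
import torch
from pytorch_transformers import XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
ids = tokenizer.encode("the dog is hairy")
pad_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

max_len = 12
pad_len = max_len - len(ids)
input_ids = torch.tensor([[pad_id] * pad_len + ids])             # pad on the LEFT
attention_mask = torch.tensor([[0] * pad_len + [1] * len(ids)])  # mask the padding
```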
@@ -187,6 +187,8 @@ class BertTokenizer(PreTrainedTokenizer):
         index = 0
         if os.path.isdir(vocab_path):
             vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
+        else:
+            vocab_file = vocab_path
         with open(vocab_file, "w", encoding="utf-8") as writer:
             for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                 if index != token_index:
......
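With the added ``else`` branch, ``save_vocabulary`` accepts either an existing directory or a direct file path. An illustrative sketch (the output paths are arbitrary):

```python
# vocab_path may now be a directory (vocab.txt is created inside it) or a file.
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.save_vocabulary('./my_model_dir')      # existing dir -> ./my_model_dir/vocab.txt
tokenizer.save_vocabulary('./custom-vocab.txt')  # a direct file path now works too
```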
@@ -193,6 +193,13 @@ class PreTrainedTokenizer(object):
         cache_dir: (`optional`) string:
             Path to a directory in which downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
+        force_download: (`optional`) boolean, default False:
+            Force to (re-)download the vocabulary files and override the cached versions if they exist.
+        proxies: (`optional`) dict, default None:
+            A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
+            The proxies are used on each request.
         inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
         kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
@@ -223,6 +230,8 @@ class PreTrainedTokenizer(object):
     @classmethod
     def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         cache_dir = kwargs.pop('cache_dir', None)
+        force_download = kwargs.pop('force_download', False)
+        proxies = kwargs.pop('proxies', None)
         s3_models = list(cls.max_model_input_sizes.keys())
         vocab_files = {}
@@ -283,7 +292,7 @@ class PreTrainedTokenizer(object):
                 if file_path is None:
                     resolved_vocab_files[file_id] = None
                 else:
-                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
+                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
             except EnvironmentError:
                 if pretrained_model_name_or_path in s3_models:
                     logger.error("Couldn't reach server to download vocabulary.")
......
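The same pair of kwargs works for tokenizer downloads as well, per the docstring added above (the proxy address is a placeholder):

```python
# Tokenizer vocabulary downloads also honor force_download/proxies now.
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    force_download=True,
    proxies={'https': 'foo.bar:3128'},  # placeholder
)
```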