Unverified Commit 8c276b9c authored by Stefan Schweter's avatar Stefan Schweter Committed by GitHub
Browse files

Merge branch 'master' into distilbert-german

parents da06afaf 3c28a2da
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ALBERT model configuration """
from .configuration_utils import PretrainedConfig
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
}
class AlbertConfig(PretrainedConfig):
"""Configuration for `AlbertModel`.
The default settings match the configuration of model `albert_xxlarge`.
"""
pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30000,
embedding_size=128,
hidden_size=4096,
num_hidden_layers=12,
num_hidden_groups=1,
num_attention_heads=64,
intermediate_size=16384,
inner_group_num=1,
hidden_act="gelu_new",
hidden_dropout_prob=0,
attention_probs_dropout_prob=0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12, **kwargs):
"""Constructs AlbertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `AlbertModel`.
embedding_size: size of voc embeddings.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_hidden_groups: Number of group for the hidden layers, parameters in
the same group are shared.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
inner_group_num: int, number of inner repetition of attention and ffn.
down_scale_factor: float, the scale to apply
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`AlbertModel`.
initializer_range: The stdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
super(AlbertConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size_or_config_json_file
self.embedding_size = embedding_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_hidden_groups = num_hidden_groups
self.num_attention_heads = num_attention_heads
self.inner_group_num = inner_group_num
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
\ No newline at end of file
...@@ -95,6 +95,9 @@ class AutoConfig(object): ...@@ -95,6 +95,9 @@ class AutoConfig(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
......
...@@ -29,6 +29,7 @@ DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { ...@@ -29,6 +29,7 @@ DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json", 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json", 'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
'distilbert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
} }
......
...@@ -94,6 +94,9 @@ class PretrainedConfig(object): ...@@ -94,6 +94,9 @@ class PretrainedConfig(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -120,6 +123,7 @@ class PretrainedConfig(object): ...@@ -120,6 +123,7 @@ class PretrainedConfig(object):
""" """
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False) force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None) proxies = kwargs.pop('proxies', None)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
...@@ -131,7 +135,8 @@ class PretrainedConfig(object): ...@@ -131,7 +135,8 @@ class PretrainedConfig(object):
config_file = pretrained_model_name_or_path config_file = pretrained_model_name_or_path
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download)
except EnvironmentError: except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map: if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format( msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
......
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALBERT checkpoint."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import torch
from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert
import logging
logging.basicConfig(level=logging.INFO)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
# Initialise PyTorch model
config = AlbertConfig.from_json_file(albert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = AlbertForMaskedLM(config)
# Load weights from tf checkpoint
load_tf_weights_in_albert(model, config, tf_checkpoint_path)
# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path to the TensorFlow checkpoint path.")
parser.add_argument("--albert_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained ALBERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.albert_config_file,
args.pytorch_dump_path)
\ No newline at end of file
...@@ -33,7 +33,8 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model, ...@@ -33,7 +33,8 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP) CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
if is_torch_available(): if is_torch_available():
import torch import torch
...@@ -46,7 +47,8 @@ if is_torch_available(): ...@@ -46,7 +47,8 @@ if is_torch_available():
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
else: else:
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
...@@ -56,7 +58,8 @@ else: ...@@ -56,7 +58,8 @@ else:
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) = ( CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP) = (
None, None, None, None, None, None, None, None,
None, None, None, None,
None, None, None, None,
...@@ -65,6 +68,7 @@ else: ...@@ -65,6 +68,7 @@ else:
None, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None,
None, None,
None, None) None, None)
...@@ -85,7 +89,8 @@ MODEL_CLASSES = { ...@@ -85,7 +89,8 @@ MODEL_CLASSES = {
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), 'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP) 'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
} }
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True): def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
......
from .processors import InputExample, InputFeatures, DataProcessor from .processors import InputExample, InputFeatures, DataProcessor
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
from .metrics import is_sklearn_available from .metrics import is_sklearn_available
if is_sklearn_available(): if is_sklearn_available():
from .metrics import glue_compute_metrics from .metrics import glue_compute_metrics, xnli_compute_metrics
...@@ -81,3 +81,11 @@ if _has_sklearn: ...@@ -81,3 +81,11 @@ if _has_sklearn:
return {"acc": simple_accuracy(preds, labels)} return {"acc": simple_accuracy(preds, labels)}
else: else:
raise KeyError(task_name) raise KeyError(task_name)
def xnli_compute_metrics(task_name, preds, labels):
assert len(preds) == len(labels)
if task_name == "xnli":
return {"acc": simple_accuracy(preds, labels)}
else:
raise KeyError(task_name)
from .utils import InputExample, InputFeatures, DataProcessor from .utils import InputExample, InputFeatures, DataProcessor
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" XNLI utils (dataset loading and evaluation) """
from __future__ import absolute_import, division, print_function
import logging
import os
from .utils import DataProcessor, InputExample
logger = logging.getLogger(__name__)
class XnliProcessor(DataProcessor):
"""Processor for the XNLI dataset.
Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""
def __init__(self, language, train_language = None):
self.language = language
self.train_language = train_language
def get_train_examples(self, data_dir):
"""See base class."""
lg = self.language if self.train_language is None else self.train_language
lines = self._read_tsv(os.path.join(data_dir, "XNLI-MT-1.0/multinli/multinli.train.{}.tsv".format(lg)))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % ('train', i)
text_a = line[0]
text_b = line[1]
label = "contradiction" if line[2] == "contradictory" else line[2]
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
language = line[0]
if language != self.language:
continue
guid = "%s-%s" % ('test', i)
text_a = line[6]
text_b = line[7]
label = line[1]
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
xnli_processors = {
"xnli": XnliProcessor,
}
xnli_output_modes = {
"xnli": "classification",
}
xnli_tasks_num_labels = {
"xnli": 3,
}
...@@ -22,6 +22,7 @@ from botocore.config import Config ...@@ -22,6 +22,7 @@ from botocore.config import Config
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
import requests import requests
from tqdm import tqdm from tqdm import tqdm
from contextlib import contextmanager
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
...@@ -152,7 +153,7 @@ def filename_to_url(filename, cache_dir=None): ...@@ -152,7 +153,7 @@ def filename_to_url(filename, cache_dir=None):
return url, etag return url, etag
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None): def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False):
""" """
Given something that might be a URL (or might be a local path), Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and determine which. If it's a URL, download the file and cache it, and
...@@ -161,6 +162,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N ...@@ -161,6 +162,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
Args: Args:
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir). cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
force_download: if True, re-dowload the file even if it's already cached in the cache dir. force_download: if True, re-dowload the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if incompletly recieved file is found.
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE cache_dir = TRANSFORMERS_CACHE
...@@ -173,7 +175,9 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N ...@@ -173,7 +175,9 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
if parsed.scheme in ('http', 'https', 's3'): if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary) # URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies) return get_from_cache(url_or_filename, cache_dir=cache_dir,
force_download=force_download, proxies=proxies,
resume_download=resume_download)
elif os.path.exists(url_or_filename): elif os.path.exists(url_or_filename):
# File, and it exists. # File, and it exists.
return url_or_filename return url_or_filename
...@@ -234,19 +238,22 @@ def s3_get(url, temp_file, proxies=None): ...@@ -234,19 +238,22 @@ def s3_get(url, temp_file, proxies=None):
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file, proxies=None): def http_get(url, temp_file, proxies=None, resume_size=0):
req = requests.get(url, stream=True, proxies=proxies) headers={'Range':'bytes=%d-'%(resume_size,)} if resume_size > 0 else None
content_length = req.headers.get('Content-Length') response = requests.get(url, stream=True, proxies=proxies, headers=headers)
total = int(content_length) if content_length is not None else None if response.status_code == 416: # Range not satisfiable
progress = tqdm(unit="B", total=total) return
for chunk in req.iter_content(chunk_size=1024): content_length = response.headers.get('Content-Length')
total = resume_size + int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total, initial=resume_size)
for chunk in response.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks
progress.update(len(chunk)) progress.update(len(chunk))
temp_file.write(chunk) temp_file.write(chunk)
progress.close() progress.close()
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10): def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False):
""" """
Given a URL, look for the corresponding dataset in the local cache. Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file. If it's not there, download it. Then return the path to the cached file.
...@@ -289,17 +296,35 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag ...@@ -289,17 +296,35 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
if matching_files: if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1]) cache_path = os.path.join(cache_dir, matching_files[-1])
if resume_download:
incomplete_path = cache_path + '.incomplete'
@contextmanager
def _resumable_file_manager():
with open(incomplete_path,'a+b') as f:
yield f
os.remove(incomplete_path)
temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size
else:
resume_size = 0
else:
temp_file_manager = tempfile.NamedTemporaryFile
resume_size = 0
if not os.path.exists(cache_path) or force_download: if not os.path.exists(cache_path) or force_download:
# Download to temporary file, then copy to cache dir once finished. # Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted. # Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file: with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
# GET file object # GET file object
if url.startswith("s3://"): if url.startswith("s3://"):
if resume_download:
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
s3_get(url, temp_file, proxies=proxies) s3_get(url, temp_file, proxies=proxies)
else: else:
http_get(url, temp_file, proxies=proxies) http_get(url, temp_file, proxies=proxies, resume_size=resume_size)
# we are copying the file before closing it, so flush to avoid truncation # we are copying the file before closing it, so flush to avoid truncation
temp_file.flush() temp_file.flush()
......
This diff is collapsed.
...@@ -27,6 +27,7 @@ from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassi ...@@ -27,6 +27,7 @@ from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassi
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, CamembertForMultipleChoice
from .modeling_utils import PreTrainedModel, SequenceSummary from .modeling_utils import PreTrainedModel, SequenceSummary
...@@ -48,6 +49,7 @@ class AutoModel(object): ...@@ -48,6 +49,7 @@ class AutoModel(object):
The base model class to instantiate is selected as the first pattern matching The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order): in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model) - contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `roberta`: RobertaModel (RoBERTa model) - contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model) - contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
...@@ -71,6 +73,7 @@ class AutoModel(object): ...@@ -71,6 +73,7 @@ class AutoModel(object):
The model class to instantiate is selected as the first pattern matching The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order): in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model) - contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `roberta`: RobertaModel (RoBERTa model) - contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model) - contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
...@@ -112,6 +115,9 @@ class AutoModel(object): ...@@ -112,6 +115,9 @@ class AutoModel(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -138,6 +144,8 @@ class AutoModel(object): ...@@ -138,6 +144,8 @@ class AutoModel(object):
""" """
if 'distilbert' in pretrained_model_name_or_path: if 'distilbert' in pretrained_model_name_or_path:
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'camembert' in pretrained_model_name_or_path:
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path: elif 'roberta' in pretrained_model_name_or_path:
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path: elif 'bert' in pretrained_model_name_or_path:
...@@ -172,6 +180,7 @@ class AutoModelWithLMHead(object): ...@@ -172,6 +180,7 @@ class AutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order): in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model) - contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model) - contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model) - contains `bert`: BertForMaskedLM (Bert model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model) - contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
...@@ -198,6 +207,7 @@ class AutoModelWithLMHead(object): ...@@ -198,6 +207,7 @@ class AutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order): in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model) - contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model) - contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model) - contains `bert`: BertForMaskedLM (Bert model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model) - contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
...@@ -237,6 +247,8 @@ class AutoModelWithLMHead(object): ...@@ -237,6 +247,8 @@ class AutoModelWithLMHead(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
...@@ -264,6 +276,8 @@ class AutoModelWithLMHead(object): ...@@ -264,6 +276,8 @@ class AutoModelWithLMHead(object):
""" """
if 'distilbert' in pretrained_model_name_or_path: if 'distilbert' in pretrained_model_name_or_path:
return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'camembert' in pretrained_model_name_or_path:
return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path: elif 'roberta' in pretrained_model_name_or_path:
return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path: elif 'bert' in pretrained_model_name_or_path:
...@@ -298,6 +312,7 @@ class AutoModelForSequenceClassification(object): ...@@ -298,6 +312,7 @@ class AutoModelForSequenceClassification(object):
The model class to instantiate is selected as the first pattern matching The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order): in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model) - contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model) - contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `bert`: BertForSequenceClassification (Bert model) - contains `bert`: BertForSequenceClassification (Bert model)
- contains `xlnet`: XLNetForSequenceClassification (XLNet model) - contains `xlnet`: XLNetForSequenceClassification (XLNet model)
...@@ -320,6 +335,7 @@ class AutoModelForSequenceClassification(object): ...@@ -320,6 +335,7 @@ class AutoModelForSequenceClassification(object):
The model class to instantiate is selected as the first pattern matching The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order): in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model) - contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model) - contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `bert`: BertForSequenceClassification (Bert model) - contains `bert`: BertForSequenceClassification (Bert model)
- contains `xlnet`: XLNetForSequenceClassification (XLNet model) - contains `xlnet`: XLNetForSequenceClassification (XLNet model)
...@@ -357,6 +373,9 @@ class AutoModelForSequenceClassification(object): ...@@ -357,6 +373,9 @@ class AutoModelForSequenceClassification(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -383,6 +402,8 @@ class AutoModelForSequenceClassification(object): ...@@ -383,6 +402,8 @@ class AutoModelForSequenceClassification(object):
""" """
if 'distilbert' in pretrained_model_name_or_path: if 'distilbert' in pretrained_model_name_or_path:
return DistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return DistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'camembert' in pretrained_model_name_or_path:
return CamembertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path: elif 'roberta' in pretrained_model_name_or_path:
return RobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return RobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path: elif 'bert' in pretrained_model_name_or_path:
......
...@@ -278,7 +278,7 @@ class BertAttention(nn.Module): ...@@ -278,7 +278,7 @@ class BertAttention(nn.Module):
if len(heads) == 0: if len(heads) == 0:
return return
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
for head in heads: for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly # Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads) head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
...@@ -597,7 +597,7 @@ class BertModel(BertPreTrainedModel): ...@@ -597,7 +597,7 @@ class BertModel(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased') model = BertModel.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids) outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
...@@ -760,7 +760,7 @@ class BertForPreTraining(BertPreTrainedModel): ...@@ -760,7 +760,7 @@ class BertForPreTraining(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased') model = BertForPreTraining.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids) outputs = model(input_ids)
prediction_scores, seq_relationship_scores = outputs[:2] prediction_scores, seq_relationship_scores = outputs[:2]
...@@ -836,7 +836,7 @@ class BertForMaskedLM(BertPreTrainedModel): ...@@ -836,7 +836,7 @@ class BertForMaskedLM(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased') model = BertForMaskedLM.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids) outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2] loss, prediction_scores = outputs[:2]
...@@ -919,7 +919,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel): ...@@ -919,7 +919,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased') model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids) outputs = model(input_ids)
seq_relationship_scores = outputs[0] seq_relationship_scores = outputs[0]
...@@ -984,7 +984,7 @@ class BertForSequenceClassification(BertPreTrainedModel): ...@@ -984,7 +984,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased') model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels) outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2] loss, logits = outputs[:2]
...@@ -1060,7 +1060,7 @@ class BertForMultipleChoice(BertPreTrainedModel): ...@@ -1060,7 +1060,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMultipleChoice.from_pretrained('bert-base-uncased') model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"] choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
labels = torch.tensor(1).unsqueeze(0) # Batch size 1 labels = torch.tensor(1).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels) outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2] loss, classification_scores = outputs[:2]
...@@ -1134,7 +1134,7 @@ class BertForTokenClassification(BertPreTrainedModel): ...@@ -1134,7 +1134,7 @@ class BertForTokenClassification(BertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('bert-base-uncased') model = BertForTokenClassification.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1 labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels) outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2] loss, scores = outputs[:2]
......
...@@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function, ...@@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function,
import logging import logging
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification
from .configuration_camembert import CamembertConfig from .configuration_camembert import CamembertConfig
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
...@@ -255,3 +255,39 @@ class CamembertForMultipleChoice(RobertaForMultipleChoice): ...@@ -255,3 +255,39 @@ class CamembertForMultipleChoice(RobertaForMultipleChoice):
""" """
config_class = CamembertConfig config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings("""CamemBERT Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
class CamembertForTokenClassification(RobertaForTokenClassification):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
Classification scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
model = CamembertForTokenClassification.from_pretrained('camembert-base')
input_ids = torch.tensor(tokenizer.encode("J'aime le camembert !", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
...@@ -44,6 +44,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = { ...@@ -44,6 +44,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-pytorch_model.bin", 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-pytorch_model.bin",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-pytorch_model.bin", 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-pytorch_model.bin",
'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-pytorch_model.bin", 'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-pytorch_model.bin",
'distilbert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-pytorch_model.bin",
} }
......
This diff is collapsed.
...@@ -109,6 +109,9 @@ class TFAutoModel(object): ...@@ -109,6 +109,9 @@ class TFAutoModel(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -237,6 +240,9 @@ class TFAutoModelWithLMHead(object): ...@@ -237,6 +240,9 @@ class TFAutoModelWithLMHead(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -360,6 +366,9 @@ class TFAutoModelForSequenceClassification(object): ...@@ -360,6 +366,9 @@ class TFAutoModelForSequenceClassification(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -472,6 +481,9 @@ class TFAutoModelForQuestionAnswering(object): ...@@ -472,6 +481,9 @@ class TFAutoModelForQuestionAnswering(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
......
...@@ -191,6 +191,9 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -191,6 +191,9 @@ class TFPreTrainedModel(tf.keras.Model):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -216,6 +219,7 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -216,6 +219,7 @@ class TFPreTrainedModel(tf.keras.Model):
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop('cache_dir', None)
from_pt = kwargs.pop('from_pt', False) from_pt = kwargs.pop('from_pt', False)
force_download = kwargs.pop('force_download', False) force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None) proxies = kwargs.pop('proxies', None)
# Load config # Load config
...@@ -224,6 +228,7 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -224,6 +228,7 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_name_or_path, *model_args, pretrained_model_name_or_path, *model_args,
cache_dir=cache_dir, return_unused_kwargs=True, cache_dir=cache_dir, return_unused_kwargs=True,
force_download=force_download, force_download=force_download,
resume_download=resume_download,
**kwargs **kwargs
) )
else: else:
...@@ -251,7 +256,8 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -251,7 +256,8 @@ class TFPreTrainedModel(tf.keras.Model):
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
resume_download=resume_download, proxies=proxies)
except EnvironmentError as e: except EnvironmentError as e:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map: if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error( logger.error(
......
...@@ -291,6 +291,9 @@ class PreTrainedModel(nn.Module): ...@@ -291,6 +291,9 @@ class PreTrainedModel(nn.Module):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -315,11 +318,16 @@ class PreTrainedModel(nn.Module): ...@@ -315,11 +318,16 @@ class PreTrainedModel(nn.Module):
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
""" """
if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path:
logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
"https://github.com/google-research/google-research/issues/119 for more information.")
config = kwargs.pop('config', None) config = kwargs.pop('config', None)
state_dict = kwargs.pop('state_dict', None) state_dict = kwargs.pop('state_dict', None)
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop('cache_dir', None)
from_tf = kwargs.pop('from_tf', False) from_tf = kwargs.pop('from_tf', False)
force_download = kwargs.pop('force_download', False) force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None) proxies = kwargs.pop('proxies', None)
output_loading_info = kwargs.pop('output_loading_info', False) output_loading_info = kwargs.pop('output_loading_info', False)
...@@ -329,6 +337,7 @@ class PreTrainedModel(nn.Module): ...@@ -329,6 +337,7 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path, *model_args, pretrained_model_name_or_path, *model_args,
cache_dir=cache_dir, return_unused_kwargs=True, cache_dir=cache_dir, return_unused_kwargs=True,
force_download=force_download, force_download=force_download,
resume_download=resume_download,
proxies=proxies, proxies=proxies,
**kwargs **kwargs
) )
...@@ -361,7 +370,8 @@ class PreTrainedModel(nn.Module): ...@@ -361,7 +370,8 @@ class PreTrainedModel(nn.Module):
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download)
except EnvironmentError: except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map: if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained weights.".format( msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment