Unverified Commit eeb70cdd authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge branch 'master' into saving-and-resuming

parents 6aa91946 ed9b8481
...@@ -35,7 +35,7 @@ class OpenAIGPTConfig(PretrainedConfig): ...@@ -35,7 +35,7 @@ class OpenAIGPTConfig(PretrainedConfig):
Configuration class to store the configuration of a `OpenAIGPTModel`. Configuration class to store the configuration of a `OpenAIGPTModel`.
Args: Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file. vocab_size: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file.
n_positions: Number of positional embeddings. n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions). n_ctx: Size of the causal mask (usually same as n_positions).
n_embd: Dimensionality of the embeddings and hidden states. n_embd: Dimensionality of the embeddings and hidden states.
...@@ -58,7 +58,7 @@ class OpenAIGPTConfig(PretrainedConfig): ...@@ -58,7 +58,7 @@ class OpenAIGPTConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
vocab_size_or_config_json_file=40478, vocab_size=40478,
n_positions=512, n_positions=512,
n_ctx=512, n_ctx=512,
n_embd=768, n_embd=768,
...@@ -71,8 +71,6 @@ class OpenAIGPTConfig(PretrainedConfig): ...@@ -71,8 +71,6 @@ class OpenAIGPTConfig(PretrainedConfig):
layer_norm_epsilon=1e-5, layer_norm_epsilon=1e-5,
initializer_range=0.02, initializer_range=0.02,
predict_special_tokens=True, predict_special_tokens=True,
num_labels=1,
summary_type='cls_index', summary_type='cls_index',
summary_use_proj=True, summary_use_proj=True,
summary_activation=None, summary_activation=None,
...@@ -83,15 +81,7 @@ class OpenAIGPTConfig(PretrainedConfig): ...@@ -83,15 +81,7 @@ class OpenAIGPTConfig(PretrainedConfig):
"""Constructs OpenAIGPTConfig. """Constructs OpenAIGPTConfig.
""" """
super(OpenAIGPTConfig, self).__init__(**kwargs) super(OpenAIGPTConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.n_ctx = n_ctx self.n_ctx = n_ctx
self.n_positions = n_positions self.n_positions = n_positions
self.n_embd = n_embd self.n_embd = n_embd
...@@ -104,18 +94,11 @@ class OpenAIGPTConfig(PretrainedConfig): ...@@ -104,18 +94,11 @@ class OpenAIGPTConfig(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.predict_special_tokens = predict_special_tokens self.predict_special_tokens = predict_special_tokens
self.num_labels = num_labels
self.summary_type = summary_type self.summary_type = summary_type
self.summary_use_proj = summary_use_proj self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels self.summary_proj_to_labels = summary_proj_to_labels
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
@property @property
def max_position_embeddings(self): def max_position_embeddings(self):
......
# coding=utf-8
# Copyright 2010, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" T5 model configuration """
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import sys
import six
from io import open
from .configuration_utils import PretrainedConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shortcut name of each pretrained T5 checkpoint -> URL of its configuration JSON.
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
    "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
    "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
    "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
    "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
}
class T5Config(PretrainedConfig):
    r"""
        :class:`~transformers.T5Config` is the configuration class to store the configuration of a
        `T5Model`.

        Arguments:
            vocab_size: Vocabulary size of `inputs_ids` in `T5Model`.
            n_positions: The maximum sequence length that this model might ever be used with
                (also readable as ``max_position_embeddings``).
            d_model: Size of the hidden states (also readable as ``hidden_size``).
            d_kv: Presumably the per-head size of the key/value projections
                (TODO: confirm against the model code).
            d_ff: Presumably the size of the "intermediate" (i.e., feed-forward) layer
                (TODO: confirm against the model code).
            num_layers: Number of hidden layers (also readable as ``num_hidden_layers``).
            num_heads: Number of attention heads for each attention layer
                (also readable as ``num_attention_heads``).
            relative_attention_num_buckets: Presumably the number of buckets used by the
                relative attention bias (TODO: confirm against the model code).
            dropout_rate: Presumably the dropout probability applied throughout the model
                (TODO: confirm against the model code).
            layer_norm_epsilon: The epsilon used by LayerNorm.
            initializer_factor: A factor for initializing all weight matrices (should be kept to 1.0,
                used for initialization testing).
            **kwargs: Any remaining keyword arguments are forwarded to ``PretrainedConfig.__init__``.

        The generic names ``max_position_embeddings``, ``hidden_size``, ``num_attention_heads`` and
        ``num_hidden_layers`` are exposed as read-only properties over the T5-specific attributes;
        they have no setters, so they cannot be assigned to directly.
    """
    pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size=32128,
                 n_positions=512,
                 d_model=512,
                 d_kv=64,
                 d_ff=2048,
                 num_layers=6,
                 num_heads=8,
                 relative_attention_num_buckets=32,
                 dropout_rate=0.1,
                 layer_norm_epsilon=1e-6,
                 initializer_factor=1.0,
                 **kwargs):
        """Constructs T5Config.
        """
        # Options shared by every configuration class are consumed by the base class.
        super(T5Config, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor

    @property
    def max_position_embeddings(self):
        """Generic alias for ``n_positions``."""
        return self.n_positions

    @property
    def hidden_size(self):
        """Generic alias for ``d_model``."""
        return self.d_model

    @property
    def num_attention_heads(self):
        """Generic alias for ``num_heads``."""
        return self.num_heads

    @property
    def num_hidden_layers(self):
        """Generic alias for ``num_layers``."""
        return self.num_layers
...@@ -34,7 +34,7 @@ class TransfoXLConfig(PretrainedConfig): ...@@ -34,7 +34,7 @@ class TransfoXLConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `TransfoXLModel`. """Configuration class to store the configuration of a `TransfoXLModel`.
Args: Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file. vocab_size: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file.
cutoffs: cutoffs for the adaptive softmax cutoffs: cutoffs for the adaptive softmax
d_model: Dimensionality of the model's hidden states. d_model: Dimensionality of the model's hidden states.
d_embed: Dimensionality of the embeddings d_embed: Dimensionality of the embeddings
...@@ -68,7 +68,7 @@ class TransfoXLConfig(PretrainedConfig): ...@@ -68,7 +68,7 @@ class TransfoXLConfig(PretrainedConfig):
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self, def __init__(self,
vocab_size_or_config_json_file=267735, vocab_size=267735,
cutoffs=[20000, 40000, 200000], cutoffs=[20000, 40000, 200000],
d_model=1024, d_model=1024,
d_embed=1024, d_embed=1024,
...@@ -100,7 +100,7 @@ class TransfoXLConfig(PretrainedConfig): ...@@ -100,7 +100,7 @@ class TransfoXLConfig(PretrainedConfig):
"""Constructs TransfoXLConfig. """Constructs TransfoXLConfig.
""" """
super(TransfoXLConfig, self).__init__(**kwargs) super(TransfoXLConfig, self).__init__(**kwargs)
self.n_token = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1 self.vocab_size = vocab_size
self.cutoffs = [] self.cutoffs = []
self.cutoffs.extend(cutoffs) self.cutoffs.extend(cutoffs)
self.tie_weight = tie_weight self.tie_weight = tie_weight
...@@ -133,27 +133,17 @@ class TransfoXLConfig(PretrainedConfig): ...@@ -133,27 +133,17 @@ class TransfoXLConfig(PretrainedConfig):
self.init_std = init_std self.init_std = init_std
self.layer_norm_epsilon = layer_norm_epsilon self.layer_norm_epsilon = layer_norm_epsilon
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif not isinstance(vocab_size_or_config_json_file, int):
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")
@property @property
def max_position_embeddings(self): def max_position_embeddings(self):
return self.tgt_len + self.ext_len + self.mem_len return self.tgt_len + self.ext_len + self.mem_len
@property @property
def vocab_size(self): def n_token(self): # Backward compatibility
return self.n_token return self.vocab_size
@vocab_size.setter @n_token.setter
def vocab_size(self, value): def n_token(self, value): # Backward compatibility
self.n_token = value self.vocab_size = value
@property @property
def hidden_size(self): def hidden_size(self):
......
...@@ -24,7 +24,7 @@ import logging ...@@ -24,7 +24,7 @@ import logging
import os import os
from io import open from io import open
from .file_utils import cached_path, CONFIG_NAME from .file_utils import CONFIG_NAME, cached_path, is_remote_url, hf_bucket_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -49,16 +49,47 @@ class PretrainedConfig(object): ...@@ -49,16 +49,47 @@ class PretrainedConfig(object):
pretrained_config_archive_map = {} pretrained_config_archive_map = {}
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.finetuning_task = kwargs.pop('finetuning_task', None) # Attributes with defaults
self.num_labels = kwargs.pop('num_labels', 2)
self.output_attentions = kwargs.pop('output_attentions', False) self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False) self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.output_past = kwargs.pop('output_past', True) # Not used by all models self.output_past = kwargs.pop('output_past', True) # Not used by all models
self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop('use_bfloat16', False) self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
self.pruned_heads = kwargs.pop('pruned_heads', {}) self.pruned_heads = kwargs.pop('pruned_heads', {})
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
self.is_decoder = kwargs.pop('is_decoder', False) self.is_decoder = kwargs.pop('is_decoder', False)
# Parameters for sequence generation
self.max_length = kwargs.pop('max_length', 20)
self.do_sample = kwargs.pop('do_sample', False)
self.num_beams = kwargs.pop('num_beams', 1)
self.temperature = kwargs.pop('temperature', 1.0)
self.top_k = kwargs.pop('top_k', 50)
self.top_p = kwargs.pop('top_p', 1.0)
self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0)
self.bos_token_id = kwargs.pop('bos_token_id', 0)
self.pad_token_id = kwargs.pop('pad_token_id', 0)
self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
self.length_penalty = kwargs.pop('length_penalty', 1.)
self.num_return_sequences = kwargs.pop('num_return_sequences', 1)
# Fine-tuning task arguments
self.finetuning_task = kwargs.pop('finetuning_task', None)
self.num_labels = kwargs.pop('num_labels', 2)
self.id2label = kwargs.pop('id2label', {i: 'LABEL_{}'.format(i) for i in range(self.num_labels)})
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.label2id = kwargs.pop('label2id', dict(zip(self.id2label.values(), self.id2label.keys())))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
# Additional attributes without default values
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error("Can't set {} with value {} for {}".format(key, value, self))
raise err
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
""" Save a configuration object to the directory `save_directory`, so that it """ Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
...@@ -79,6 +110,7 @@ class PretrainedConfig(object): ...@@ -79,6 +110,7 @@ class PretrainedConfig(object):
pretrained_model_name_or_path: either: pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. - a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
- a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
...@@ -131,12 +163,18 @@ class PretrainedConfig(object): ...@@ -131,12 +163,18 @@ class PretrainedConfig(object):
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path): elif os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
else: elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path config_file = pretrained_model_name_or_path
# redirect to the cache, if necessary else:
config_file = hf_bucket_url(pretrained_model_name_or_path, postfix=CONFIG_NAME)
try: try:
# Load from URL or cache if already cached
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download) proxies=proxies, resume_download=resume_download)
# Load config
config = cls.from_json_file(resolved_config_file)
except EnvironmentError: except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map: if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format( msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
...@@ -150,15 +188,18 @@ class PretrainedConfig(object): ...@@ -150,15 +188,18 @@ class PretrainedConfig(object):
config_file, CONFIG_NAME) config_file, CONFIG_NAME)
raise EnvironmentError(msg) raise EnvironmentError(msg)
except json.JSONDecodeError:
msg = "Couldn't reach server at '{}' to download configuration file or " \
"configuration file is not a valid JSON file. " \
"Please check network or file content here: {}.".format(config_file, resolved_config_file)
raise EnvironmentError(msg)
if resolved_config_file == config_file: if resolved_config_file == config_file:
logger.info("loading configuration file {}".format(config_file)) logger.info("loading configuration file {}".format(config_file))
else: else:
logger.info("loading configuration file {} from cache at {}".format( logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file)) config_file, resolved_config_file))
# Load config
config = cls.from_json_file(resolved_config_file)
if hasattr(config, 'pruned_heads'): if hasattr(config, 'pruned_heads'):
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
...@@ -180,17 +221,15 @@ class PretrainedConfig(object): ...@@ -180,17 +221,15 @@ class PretrainedConfig(object):
@classmethod @classmethod
def from_dict(cls, json_object): def from_dict(cls, json_object):
"""Constructs a `Config` from a Python dictionary of parameters.""" """Constructs a `Config` from a Python dictionary of parameters."""
config = cls(vocab_size_or_config_json_file=-1) return cls(**json_object)
for key, value in json_object.items():
setattr(config, key, value)
return config
@classmethod @classmethod
def from_json_file(cls, json_file): def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters.""" """Constructs a `Config` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader: with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read() text = reader.read()
return cls.from_dict(json.loads(text)) dict_obj = json.loads(text)
return cls(**dict_obj)
def __eq__(self, other): def __eq__(self, other):
return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
......
...@@ -42,7 +42,7 @@ class XLMConfig(PretrainedConfig): ...@@ -42,7 +42,7 @@ class XLMConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `XLMModel`. """Configuration class to store the configuration of a `XLMModel`.
Args: Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XLMModel`. vocab_size: Vocabulary size of `inputs_ids` in `XLMModel`.
d_model: Size of the encoder layers and the pooler layer. d_model: Size of the encoder layers and the pooler layer.
n_layer: Number of hidden layers in the Transformer encoder. n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in n_head: Number of attention heads for each attention layer in
...@@ -81,7 +81,7 @@ class XLMConfig(PretrainedConfig): ...@@ -81,7 +81,7 @@ class XLMConfig(PretrainedConfig):
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self, def __init__(self,
vocab_size_or_config_json_file=30145, vocab_size=30145,
emb_dim=2048, emb_dim=2048,
n_layers=12, n_layers=12,
n_heads=16, n_heads=16,
...@@ -103,9 +103,6 @@ class XLMConfig(PretrainedConfig): ...@@ -103,9 +103,6 @@ class XLMConfig(PretrainedConfig):
unk_index=3, unk_index=3,
mask_index=5, mask_index=5,
is_encoder=True, is_encoder=True,
finetuning_task=None,
num_labels=2,
summary_type='first', summary_type='first',
summary_use_proj=True, summary_use_proj=True,
summary_activation=None, summary_activation=None,
...@@ -113,19 +110,13 @@ class XLMConfig(PretrainedConfig): ...@@ -113,19 +110,13 @@ class XLMConfig(PretrainedConfig):
summary_first_dropout=0.1, summary_first_dropout=0.1,
start_n_top=5, start_n_top=5,
end_n_top=5, end_n_top=5,
mask_token_id=0,
lang_id=0,
**kwargs): **kwargs):
"""Constructs XLMConfig. """Constructs XLMConfig.
""" """
super(XLMConfig, self).__init__(**kwargs) super(XLMConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.n_words = vocab_size_or_config_json_file
self.emb_dim = emb_dim self.emb_dim = emb_dim
self.n_layers = n_layers self.n_layers = n_layers
self.n_heads = n_heads self.n_heads = n_heads
...@@ -147,8 +138,6 @@ class XLMConfig(PretrainedConfig): ...@@ -147,8 +138,6 @@ class XLMConfig(PretrainedConfig):
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.embed_init_std = embed_init_std self.embed_init_std = embed_init_std
self.init_std = init_std self.init_std = init_std
self.finetuning_task = finetuning_task
self.num_labels = num_labels
self.summary_type = summary_type self.summary_type = summary_type
self.summary_use_proj = summary_use_proj self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation self.summary_activation = summary_activation
...@@ -156,17 +145,19 @@ class XLMConfig(PretrainedConfig): ...@@ -156,17 +145,19 @@ class XLMConfig(PretrainedConfig):
self.summary_first_dropout = summary_first_dropout self.summary_first_dropout = summary_first_dropout
self.start_n_top = start_n_top self.start_n_top = start_n_top
self.end_n_top = end_n_top self.end_n_top = end_n_top
else: self.mask_token_id = mask_token_id
raise ValueError("First argument must be either a vocabulary size (int)" self.lang_id = lang_id
" or the path to a pretrained model config file (str)")
if "n_words" in kwargs:
self.n_words = kwargs["n_words"]
@property @property
def vocab_size(self): def n_words(self): # For backward compatibility
return self.n_words return self.vocab_size
@vocab_size.setter @n_words.setter
def vocab_size(self, value): def n_words(self, value): # For backward compatibility
self.n_words = value self.vocab_size = value
@property @property
def hidden_size(self): def hidden_size(self):
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" XLM-RoBERTa configuration """
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import logging
from .configuration_roberta import RobertaConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shortcut name of each pretrained XLM-RoBERTa checkpoint -> URL of its configuration JSON.
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "xlm-roberta-base":
        "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
    "xlm-roberta-large":
        "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
    "xlm-roberta-large-finetuned-conll02-dutch":
        "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
    "xlm-roberta-large-finetuned-conll02-spanish":
        "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
    "xlm-roberta-large-finetuned-conll03-english":
        "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
    "xlm-roberta-large-finetuned-conll03-german":
        "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
}
class XLMRobertaConfig(RobertaConfig):
    """Configuration class for XLM-RoBERTa.

    Defines nothing of its own beyond `pretrained_config_archive_map`, which it
    points at the XLM-RoBERTa shortcut names; everything else is inherited from
    `RobertaConfig`.
    """
    pretrained_config_archive_map = XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
...@@ -35,7 +35,7 @@ class XLNetConfig(PretrainedConfig): ...@@ -35,7 +35,7 @@ class XLNetConfig(PretrainedConfig):
"""Configuration class to store the configuration of a ``XLNetModel``. """Configuration class to store the configuration of a ``XLNetModel``.
Args: Args:
vocab_size_or_config_json_file: Vocabulary size of ``inputs_ids`` in ``XLNetModel``. vocab_size: Vocabulary size of ``inputs_ids`` in ``XLNetModel``.
d_model: Size of the encoder layers and the pooler layer. d_model: Size of the encoder layers and the pooler layer.
n_layer: Number of hidden layers in the Transformer encoder. n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in n_head: Number of attention heads for each attention layer in
...@@ -72,28 +72,22 @@ class XLNetConfig(PretrainedConfig): ...@@ -72,28 +72,22 @@ class XLNetConfig(PretrainedConfig):
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self, def __init__(self,
vocab_size_or_config_json_file=32000, vocab_size=32000,
d_model=1024, d_model=1024,
n_layer=24, n_layer=24,
n_head=16, n_head=16,
d_inner=4096, d_inner=4096,
max_position_embeddings=512,
ff_activation="gelu", ff_activation="gelu",
untie_r=True, untie_r=True,
attn_type="bi", attn_type="bi",
initializer_range=0.02, initializer_range=0.02,
layer_norm_eps=1e-12, layer_norm_eps=1e-12,
dropout=0.1, dropout=0.1,
mem_len=None, mem_len=None,
reuse_len=None, reuse_len=None,
bi_data=False, bi_data=False,
clamp_len=-1, clamp_len=-1,
same_length=False, same_length=False,
finetuning_task=None,
num_labels=2,
summary_type='last', summary_type='last',
summary_use_proj=True, summary_use_proj=True,
summary_activation='tanh', summary_activation='tanh',
...@@ -104,15 +98,7 @@ class XLNetConfig(PretrainedConfig): ...@@ -104,15 +98,7 @@ class XLNetConfig(PretrainedConfig):
"""Constructs XLNetConfig. """Constructs XLNetConfig.
""" """
super(XLNetConfig, self).__init__(**kwargs) super(XLNetConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
setattr(config, key, value)
elif isinstance(vocab_size_or_config_json_file, int):
self.n_token = vocab_size_or_config_json_file
self.d_model = d_model self.d_model = d_model
self.n_layer = n_layer self.n_layer = n_layer
self.n_head = n_head self.n_head = n_head
...@@ -133,29 +119,24 @@ class XLNetConfig(PretrainedConfig): ...@@ -133,29 +119,24 @@ class XLNetConfig(PretrainedConfig):
self.clamp_len = clamp_len self.clamp_len = clamp_len
self.same_length = same_length self.same_length = same_length
self.finetuning_task = finetuning_task
self.num_labels = num_labels
self.summary_type = summary_type self.summary_type = summary_type
self.summary_use_proj = summary_use_proj self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation self.summary_activation = summary_activation
self.summary_last_dropout = summary_last_dropout self.summary_last_dropout = summary_last_dropout
self.start_n_top = start_n_top self.start_n_top = start_n_top
self.end_n_top = end_n_top self.end_n_top = end_n_top
else:
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")
@property @property
def max_position_embeddings(self): def max_position_embeddings(self):
return -1 return -1
@property @property
def vocab_size(self): def n_token(self): # Backward compatibility
return self.n_token return self.vocab_size
@vocab_size.setter @n_token.setter
def vocab_size(self, value): def n_token(self, value): # Backward compatibility
self.n_token = value self.vocab_size = value
@property @property
def hidden_size(self): def hidden_size(self):
......
...@@ -32,9 +32,10 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model, ...@@ -32,9 +32,10 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, TFDistilBertForSequenceClassification, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP) AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP)
if is_torch_available(): if is_torch_available():
import torch import torch
...@@ -46,9 +47,10 @@ if is_torch_available(): ...@@ -46,9 +47,10 @@ if is_torch_available():
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP) AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP)
else: else:
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
...@@ -57,9 +59,10 @@ else: ...@@ -57,9 +59,10 @@ else:
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DistilBertForMaskedLM, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP) = ( AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = (
None, None, None, None, None, None, None, None,
None, None, None, None,
None, None, None, None,
...@@ -67,7 +70,8 @@ else: ...@@ -67,7 +70,8 @@ else:
None, None, None, None,
None, None, None, None,
None, None, None, None, None, None,
None, None, None, None, None, None, None,
None, None,
None, None, None, None,
None, None) None, None)
...@@ -89,8 +93,10 @@ MODEL_CLASSES = { ...@@ -89,8 +93,10 @@ MODEL_CLASSES = {
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), 'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP), 'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP) 'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
} }
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True): def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
...@@ -115,24 +121,21 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file ...@@ -115,24 +121,21 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path) tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
if compare_with_pt_model: if compare_with_pt_model:
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
tf_inputs = tf.constant(inputs_list)
tfo = tf_model(tf_inputs, training=False) # build the network
state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu') state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None, pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
config=config, config=config,
state_dict=state_dict) state_dict=state_dict)
pt_inputs = torch.tensor(inputs_list)
with torch.no_grad(): with torch.no_grad():
pto = pt_model(pt_inputs) pto = pt_model(**pt_model.dummy_inputs)
np_pt = pto[0].detach().numpy() np_pt = pto[0].numpy()
np_tf = tfo[0].numpy() np_tf = tfo[0].numpy()
diff = np.amax(np.abs(np_pt - np_tf)) diff = np.amax(np.abs(np_pt - np_tf))
print("Max absolute difference between models outputs {}".format(diff)) print("Max absolute difference between models outputs {}".format(diff))
assert diff <= 2e-2, "Error, model absolute difference is >2e-2" assert diff <= 2e-2, "Error, model absolute difference is >2e-2: {}".format(diff)
# Save pytorch-model # Save pytorch-model
print("Save TensorFlow model to {}".format(tf_dump_path)) print("Save TensorFlow model to {}".format(tf_dump_path))
......
...@@ -20,6 +20,13 @@ import argparse ...@@ -20,6 +20,13 @@ import argparse
import logging import logging
import numpy as np import numpy as np
import torch import torch
import pathlib
import fairseq
from packaging import version
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
raise Exception("requires fairseq >= 0.9.0")
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer from fairseq.modules import TransformerSentenceEncoderLayer
...@@ -45,8 +52,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_ ...@@ -45,8 +52,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
""" """
roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path) roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
roberta.eval() # disable dropout roberta.eval() # disable dropout
roberta_sent_encoder = roberta.model.decoder.sentence_encoder
config = BertConfig( config = BertConfig(
vocab_size_or_config_json_file=50265, vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
hidden_size=roberta.args.encoder_embed_dim, hidden_size=roberta.args.encoder_embed_dim,
num_hidden_layers=roberta.args.encoder_layers, num_hidden_layers=roberta.args.encoder_layers,
num_attention_heads=roberta.args.encoder_attention_heads, num_attention_heads=roberta.args.encoder_attention_heads,
...@@ -64,7 +72,6 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_ ...@@ -64,7 +72,6 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
# Now let's copy all the weights. # Now let's copy all the weights.
# Embeddings # Embeddings
roberta_sent_encoder = roberta.model.decoder.sentence_encoder
model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them. model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them.
...@@ -79,15 +86,18 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_ ...@@ -79,15 +86,18 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
### self attention ### self attention
self_attn: BertSelfAttention = layer.attention.self self_attn: BertSelfAttention = layer.attention.self
assert( assert(
roberta_layer.self_attn.in_proj_weight.shape == torch.Size((3 * config.hidden_size, config.hidden_size)) roberta_layer.self_attn.k_proj.weight.data.shape == \
roberta_layer.self_attn.q_proj.weight.data.shape == \
roberta_layer.self_attn.v_proj.weight.data.shape == \
torch.Size((config.hidden_size, config.hidden_size))
) )
# we use three distinct linear layers so we split the source layer here.
self_attn.query.weight.data = roberta_layer.self_attn.in_proj_weight[:config.hidden_size, :] self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight
self_attn.query.bias.data = roberta_layer.self_attn.in_proj_bias[:config.hidden_size] self_attn.query.bias.data = roberta_layer.self_attn.q_proj.bias
self_attn.key.weight.data = roberta_layer.self_attn.in_proj_weight[config.hidden_size:2*config.hidden_size, :] self_attn.key.weight.data = roberta_layer.self_attn.k_proj.weight
self_attn.key.bias.data = roberta_layer.self_attn.in_proj_bias[config.hidden_size:2*config.hidden_size] self_attn.key.bias.data = roberta_layer.self_attn.k_proj.bias
self_attn.value.weight.data = roberta_layer.self_attn.in_proj_weight[2*config.hidden_size:, :] self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
self_attn.value.bias.data = roberta_layer.self_attn.in_proj_bias[2*config.hidden_size:] self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias
### self-attention output ### self-attention output
self_output: BertSelfOutput = layer.attention.output self_output: BertSelfOutput = layer.attention.output
...@@ -151,6 +161,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_ ...@@ -151,6 +161,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
if not success: if not success:
raise Exception("Something went wRoNg") raise Exception("Something went wRoNg")
pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}") print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path) model.save_pretrained(pytorch_dump_folder_path)
......
# coding=utf-8
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert T5 checkpoint."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import torch
from transformers import T5Config, T5Model, load_tf_weights_in_t5
import logging
logging.basicConfig(level=logging.INFO)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    """Convert a TensorFlow T5 checkpoint into a saved PyTorch state dict.

    Args:
        tf_checkpoint_path: Path to the TensorFlow checkpoint; handed to
            `load_tf_weights_in_t5` together with the model and config.
        config_file: Path to the JSON configuration file, read with
            `T5Config.from_json_file`.
        pytorch_dump_path: Destination path given to `torch.save` for the
            model's `state_dict()`.
    """
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = T5Model(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save pytorch-model (only the state dict is written, not the whole module)
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
    # Command-line entry point: parse the three required paths and run the
    # TensorFlow -> PyTorch conversion with them.
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--tf_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the TensorFlow checkpoint path.")
    parser.add_argument("--config_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The config json file corresponding to the pre-trained T5 model. \n"
                             "This specifies the model architecture.")
    parser.add_argument("--pytorch_dump_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                     args.config_file,
                                     args.pytorch_dump_path)
from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures, SingleSentenceClassificationProcessor
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
......
from .utils import InputExample, InputFeatures, DataProcessor from .utils import InputExample, InputFeatures, DataProcessor, SingleSentenceClassificationProcessor
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
\ No newline at end of file
...@@ -18,19 +18,20 @@ if is_tf_available(): ...@@ -18,19 +18,20 @@ if is_tf_available():
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer.

    The annotated answer is re-tokenized with `tokenizer.tokenize` and joined
    with single spaces. Every sub-span of `doc_tokens` inside
    `[input_start, input_end]` (both inclusive) is then compared against that
    string, trying the earliest start first and, for each start, the latest
    end first.

    Returns:
        A `(start, end)` tuple of inclusive token indices: the first sub-span
        whose space-joined tokens equal the tokenized answer, or
        `(input_start, input_end)` unchanged when no sub-span matches.
    """
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        # Walk the end index backwards so the widest candidate is tried first.
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)
def _check_is_max_context(doc_spans, cur_span_index, position): def _check_is_max_context(doc_spans, cur_span_index, position):
"""Check if this is the 'max context' doc span for the token.""" """Check if this is the 'max context' doc span for the token."""
best_score = None best_score = None
...@@ -50,6 +51,7 @@ def _check_is_max_context(doc_spans, cur_span_index, position): ...@@ -50,6 +51,7 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
return cur_span_index == best_span_index return cur_span_index == best_span_index
def _new_check_is_max_context(doc_spans, cur_span_index, position): def _new_check_is_max_context(doc_spans, cur_span_index, position):
"""Check if this is the 'max context' doc span for the token.""" """Check if this is the 'max context' doc span for the token."""
# if len(doc_spans) == 1: # if len(doc_spans) == 1:
...@@ -71,14 +73,16 @@ def _new_check_is_max_context(doc_spans, cur_span_index, position): ...@@ -71,14 +73,16 @@ def _new_check_is_max_context(doc_spans, cur_span_index, position):
return cur_span_index == best_span_index return cur_span_index == best_span_index
def _is_whitespace(c):
    """Return True when the single character `c` counts as whitespace.

    Matches space, tab, carriage return, newline, and the code point 0x202F
    (NARROW NO-BREAK SPACE); every other character yields False.
    """
    return c in (" ", "\t", "\r", "\n") or ord(c) == 0x202F
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
doc_stride, max_query_length, is_training, def squad_convert_examples_to_features(
return_dataset=False): examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False
):
""" """
Converts a list of examples into a list of features that can be directly given as input to a model. Converts a list of examples into a list of features that can be directly given as input to a model.
It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs. It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
...@@ -116,20 +120,19 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -116,20 +120,19 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
unique_id = 1000000000 unique_id = 1000000000
features = [] features = []
for (example_index, example) in enumerate(tqdm(examples)): for (example_index, example) in enumerate(tqdm(examples, desc="Converting examples to features")):
if is_training and not example.is_impossible: if is_training and not example.is_impossible:
# Get start and end position # Get start and end position
start_position = example.start_position start_position = example.start_position
end_position = example.end_position end_position = example.end_position
# If the answer cannot be found in the text, then skip this example. # If the answer cannot be found in the text, then skip this example.
actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)]) actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text)) cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
if actual_text.find(cleaned_answer_text) == -1: if actual_text.find(cleaned_answer_text) == -1:
logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text) logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
continue continue
tok_to_orig_index = [] tok_to_orig_index = []
orig_to_tok_index = [] orig_to_tok_index = []
all_doc_tokens = [] all_doc_tokens = []
...@@ -140,7 +143,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -140,7 +143,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
tok_to_orig_index.append(i) tok_to_orig_index.append(i)
all_doc_tokens.append(sub_token) all_doc_tokens.append(sub_token)
if is_training and not example.is_impossible: if is_training and not example.is_impossible:
tok_start_position = orig_to_tok_index[example.start_position] tok_start_position = orig_to_tok_index[example.start_position]
if example.end_position < len(example.doc_tokens) - 1: if example.end_position < len(example.doc_tokens) - 1:
...@@ -154,7 +156,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -154,7 +156,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
spans = [] spans = []
truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length) truncated_query = tokenizer.encode(
example.question_text, add_special_tokens=False, max_length=max_query_length
)
sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
...@@ -168,15 +172,18 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -168,15 +172,18 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
return_overflowing_tokens=True, return_overflowing_tokens=True,
pad_to_max_length=True, pad_to_max_length=True,
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens, stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first' truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
) )
paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens) paragraph_len = min(
len(all_doc_tokens) - len(spans) * doc_stride,
max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
)
if tokenizer.pad_token_id in encoded_dict['input_ids']: if tokenizer.pad_token_id in encoded_dict["input_ids"]:
non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)] non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
else: else:
non_padded_ids = encoded_dict['input_ids'] non_padded_ids = encoded_dict["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
...@@ -202,16 +209,20 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -202,16 +209,20 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
for doc_span_index in range(len(spans)): for doc_span_index in range(len(spans)):
for j in range(spans[doc_span_index]["paragraph_len"]): for j in range(spans[doc_span_index]["paragraph_len"]):
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j) is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
index = j if tokenizer.padding_side == "left" else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j index = (
j
if tokenizer.padding_side == "left"
else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
)
spans[doc_span_index]["token_is_max_context"][index] = is_max_context spans[doc_span_index]["token_is_max_context"][index] = is_max_context
for span in spans: for span in spans:
# Identify the position of the CLS token # Identify the position of the CLS token
cls_index = span['input_ids'].index(tokenizer.cls_token_id) cls_index = span["input_ids"].index(tokenizer.cls_token_id)
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer) # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
# Original TF implem also keep the classification token (set to 0) (not sure why...) # Original TF implem also keep the classification token (set to 0) (not sure why...)
p_mask = np.array(span['token_type_ids']) p_mask = np.array(span["token_type_ids"])
p_mask = np.minimum(p_mask, 1) p_mask = np.minimum(p_mask, 1)
...@@ -224,7 +235,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -224,7 +235,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
# Set the CLS index to '0' # Set the CLS index to '0'
p_mask[cls_index] = 0 p_mask[cls_index] = 0
span_is_impossible = example.is_impossible span_is_impossible = example.is_impossible
start_position = 0 start_position = 0
end_position = 0 end_position = 0
...@@ -251,51 +261,95 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -251,51 +261,95 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
start_position = tok_start_position - doc_start + doc_offset start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset end_position = tok_end_position - doc_start + doc_offset
features.append(
features.append(SquadFeatures( SquadFeatures(
span['input_ids'], span["input_ids"],
span['attention_mask'], span["attention_mask"],
span['token_type_ids'], span["token_type_ids"],
cls_index, cls_index,
p_mask.tolist(), p_mask.tolist(),
example_index=example_index, example_index=example_index,
unique_id=unique_id, unique_id=unique_id,
paragraph_len=span['paragraph_len'], paragraph_len=span["paragraph_len"],
token_is_max_context=span["token_is_max_context"], token_is_max_context=span["token_is_max_context"],
tokens=span["tokens"], tokens=span["tokens"],
token_to_orig_map=span["token_to_orig_map"], token_to_orig_map=span["token_to_orig_map"],
start_position=start_position, start_position=start_position,
end_position=end_position end_position=end_position,
)) )
)
unique_id += 1 unique_id += 1
if return_dataset == 'pt': if return_dataset == "pt":
if not is_torch_available(): if not is_torch_available():
raise ImportError("Pytorch must be installed to return a pytorch dataset.") raise ImportError("Pytorch must be installed to return a pytorch dataset.")
# Convert to Tensors and build dataset # Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
if not is_training: if not is_training:
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, dataset = TensorDataset(
all_example_index, all_cls_index, all_p_mask) all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask
)
else: else:
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, dataset = TensorDataset(
all_start_positions, all_end_positions, all_input_ids,
all_cls_index, all_p_mask) all_attention_masks,
all_token_type_ids,
all_start_positions,
all_end_positions,
all_cls_index,
all_p_mask,
)
return features, dataset return features, dataset
elif return_dataset == "tf":
if not is_tf_available():
raise ImportError("TensorFlow must be installed to return a TensorFlow dataset.")
def gen():
for ex in features:
yield (
{
"input_ids": ex.input_ids,
"attention_mask": ex.attention_mask,
"token_type_ids": ex.token_type_ids,
}, {
"start_position": ex.start_position,
"end_position": ex.end_position,
"cls_index": ex.cls_index,
"p_mask": ex.p_mask,
}
)
return tf.data.Dataset.from_generator(
gen,
(
{"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
{"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32},
),
(
{
"input_ids": tf.TensorShape([None]),
"attention_mask": tf.TensorShape([None]),
"token_type_ids": tf.TensorShape([None]),
},
{
"start_position": tf.TensorShape([]),
"end_position": tf.TensorShape([]),
"cls_index": tf.TensorShape([]),
"p_mask": tf.TensorShape([None]),
},
),
)
return features return features
...@@ -305,31 +359,32 @@ class SquadProcessor(DataProcessor): ...@@ -305,31 +359,32 @@ class SquadProcessor(DataProcessor):
Processor for the SQuAD data set. Processor for the SQuAD data set.
Overriden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and version 2.0 of SQuAD, respectively. Overriden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and version 2.0 of SQuAD, respectively.
""" """
train_file = None train_file = None
dev_file = None dev_file = None
def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False): def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
if not evaluate: if not evaluate:
answer = tensor_dict['answers']['text'][0].numpy().decode('utf-8') answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
answer_start = tensor_dict['answers']['answer_start'][0].numpy() answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
answers = [] answers = []
else: else:
answers = [{ answers = [
"answer_start": start.numpy(), {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
"text": text.numpy().decode('utf-8') for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
} for start, text in zip(tensor_dict['answers']["answer_start"], tensor_dict['answers']["text"])] ]
answer = None answer = None
answer_start = None answer_start = None
return SquadExample( return SquadExample(
qas_id=tensor_dict['id'].numpy().decode("utf-8"), qas_id=tensor_dict["id"].numpy().decode("utf-8"),
question_text=tensor_dict['question'].numpy().decode('utf-8'), question_text=tensor_dict["question"].numpy().decode("utf-8"),
context_text=tensor_dict['context'].numpy().decode('utf-8'), context_text=tensor_dict["context"].numpy().decode("utf-8"),
answer_text=answer, answer_text=answer,
start_position_character=answer_start, start_position_character=answer_start,
title=tensor_dict['title'].numpy().decode('utf-8'), title=tensor_dict["title"].numpy().decode("utf-8"),
answers=answers answers=answers,
) )
def get_examples_from_dataset(self, dataset, evaluate=False): def get_examples_from_dataset(self, dataset, evaluate=False):
...@@ -373,10 +428,15 @@ class SquadProcessor(DataProcessor): ...@@ -373,10 +428,15 @@ class SquadProcessor(DataProcessor):
which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
""" """
if data_dir is None:
data_dir = ""
if self.train_file is None: if self.train_file is None:
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
with open(os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding='utf-8') as reader: with open(
os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
) as reader:
input_data = json.load(reader)["data"] input_data = json.load(reader)["data"]
return self._create_examples(input_data, "train") return self._create_examples(input_data, "train")
...@@ -389,10 +449,15 @@ class SquadProcessor(DataProcessor): ...@@ -389,10 +449,15 @@ class SquadProcessor(DataProcessor):
filename: None by default, specify this if the evaluation file has a different name than the original one filename: None by default, specify this if the evaluation file has a different name than the original one
which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
""" """
if data_dir is None:
data_dir = ""
if self.dev_file is None: if self.dev_file is None:
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
with open(os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding='utf-8') as reader: with open(
os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
) as reader:
input_data = json.load(reader)["data"] input_data = json.load(reader)["data"]
return self._create_examples(input_data, "dev") return self._create_examples(input_data, "dev")
...@@ -400,7 +465,7 @@ class SquadProcessor(DataProcessor): ...@@ -400,7 +465,7 @@ class SquadProcessor(DataProcessor):
is_training = set_type == "train" is_training = set_type == "train"
examples = [] examples = []
for entry in tqdm(input_data): for entry in tqdm(input_data):
title = entry['title'] title = entry["title"]
for paragraph in entry["paragraphs"]: for paragraph in entry["paragraphs"]:
context_text = paragraph["context"] context_text = paragraph["context"]
for qa in paragraph["qas"]: for qa in paragraph["qas"]:
...@@ -418,8 +483,8 @@ class SquadProcessor(DataProcessor): ...@@ -418,8 +483,8 @@ class SquadProcessor(DataProcessor):
if not is_impossible: if not is_impossible:
if is_training: if is_training:
answer = qa["answers"][0] answer = qa["answers"][0]
answer_text = answer['text'] answer_text = answer["text"]
start_position_character = answer['answer_start'] start_position_character = answer["answer_start"]
else: else:
answers = qa["answers"] answers = qa["answers"]
...@@ -431,12 +496,13 @@ class SquadProcessor(DataProcessor): ...@@ -431,12 +496,13 @@ class SquadProcessor(DataProcessor):
start_position_character=start_position_character, start_position_character=start_position_character,
title=title, title=title,
is_impossible=is_impossible, is_impossible=is_impossible,
answers=answers answers=answers,
) )
examples.append(example) examples.append(example)
return examples return examples
class SquadV1Processor(SquadProcessor): class SquadV1Processor(SquadProcessor):
train_file = "train-v1.1.json" train_file = "train-v1.1.json"
dev_file = "dev-v1.1.json" dev_file = "dev-v1.1.json"
...@@ -462,7 +528,8 @@ class SquadExample(object): ...@@ -462,7 +528,8 @@ class SquadExample(object):
is_impossible: False by default, set to True if the example has no possible answer. is_impossible: False by default, set to True if the example has no possible answer.
""" """
def __init__(self, def __init__(
self,
qas_id, qas_id,
question_text, question_text,
context_text, context_text,
...@@ -470,7 +537,8 @@ class SquadExample(object): ...@@ -470,7 +537,8 @@ class SquadExample(object):
start_position_character, start_position_character,
title, title,
answers=[], answers=[],
is_impossible=False): is_impossible=False,
):
self.qas_id = qas_id self.qas_id = qas_id
self.question_text = question_text self.question_text = question_text
self.context_text = context_text self.context_text = context_text
...@@ -503,7 +571,9 @@ class SquadExample(object): ...@@ -503,7 +571,9 @@ class SquadExample(object):
# Start end end positions only has a value during evaluation. # Start end end positions only has a value during evaluation.
if start_position_character is not None and not is_impossible: if start_position_character is not None and not is_impossible:
self.start_position = char_to_word_offset[start_position_character] self.start_position = char_to_word_offset[start_position_character]
self.end_position = char_to_word_offset[start_position_character + len(answer_text) - 1] self.end_position = char_to_word_offset[
min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1)
]
class SquadFeatures(object): class SquadFeatures(object):
...@@ -531,22 +601,21 @@ class SquadFeatures(object): ...@@ -531,22 +601,21 @@ class SquadFeatures(object):
end_position: end of the answer token index end_position: end of the answer token index
""" """
def __init__(self, def __init__(
self,
input_ids, input_ids,
attention_mask, attention_mask,
token_type_ids, token_type_ids,
cls_index, cls_index,
p_mask, p_mask,
example_index, example_index,
unique_id, unique_id,
paragraph_len, paragraph_len,
token_is_max_context, token_is_max_context,
tokens, tokens,
token_to_orig_map, token_to_orig_map,
start_position, start_position,
end_position end_position,
): ):
self.input_ids = input_ids self.input_ids = input_ids
self.attention_mask = attention_mask self.attention_mask = attention_mask
...@@ -574,6 +643,7 @@ class SquadResult(object): ...@@ -574,6 +643,7 @@ class SquadResult(object):
start_logits: The logits corresponding to the start of the answer start_logits: The logits corresponding to the start of the answer
end_logits: The logits corresponding to the end of the answer end_logits: The logits corresponding to the end of the answer
""" """
def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None): def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
self.start_logits = start_logits self.start_logits = start_logits
self.end_logits = end_logits self.end_logits = end_logits
......
...@@ -18,6 +18,11 @@ import csv ...@@ -18,6 +18,11 @@ import csv
import sys import sys
import copy import copy
import json import json
import logging
from ...file_utils import is_tf_available, is_torch_available
logger = logging.getLogger(__name__)
class InputExample(object): class InputExample(object):
""" """
...@@ -64,7 +69,7 @@ class InputFeatures(object): ...@@ -64,7 +69,7 @@ class InputFeatures(object):
label: Label corresponding to the input label: Label corresponding to the input
""" """
def __init__(self, input_ids, attention_mask, token_type_ids, label): def __init__(self, input_ids, attention_mask=None, token_type_ids=None, label=None):
self.input_ids = input_ids self.input_ids = input_ids
self.attention_mask = attention_mask self.attention_mask = attention_mask
self.token_type_ids = token_type_ids self.token_type_ids = token_type_ids
...@@ -86,34 +91,6 @@ class InputFeatures(object): ...@@ -86,34 +91,6 @@ class InputFeatures(object):
class DataProcessor(object): class DataProcessor(object):
"""Base class for data converters for sequence classification data sets.""" """Base class for data converters for sequence classification data sets."""
def get_example_from_tensor_dict(self, tensor_dict):
    """Build an example out of a dictionary of TensorFlow tensors.

    This base implementation is a placeholder that always raises;
    concrete processors are expected to override it.

    Args:
        tensor_dict: dictionary whose keys and values should follow the layout
            of the corresponding GLUE ``tensorflow_datasets`` examples.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
def get_train_examples(self, data_dir):
    """Return the collection of `InputExample`s of the train set.

    Placeholder in the base class: always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
def get_dev_examples(self, data_dir):
    """Return the collection of `InputExample`s of the dev set.

    Placeholder in the base class: always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
def get_labels(self):
    """Return the list of labels of this data set.

    Placeholder in the base class: always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
def tfds_map(self, example):
    """Convert a ``tensorflow_datasets`` example to the GLUE label format.

    Some tensorflow_datasets datasets are not formatted the same way the GLUE
    datasets are. When the processor knows more than one label, the example's
    label is treated as an index into ``self.get_labels()`` and replaced, in
    place, by the label found at that index.

    Args:
        example: object exposing a mutable ``label`` attribute.

    Returns:
        The same ``example`` object (modified in place when applicable).
    """
    # With zero or one label there is nothing to map: hand the example back untouched.
    if len(self.get_labels()) <= 1:
        return example
    example.label = self.get_labels()[int(example.label)]
    return example
@classmethod @classmethod
def _read_tsv(cls, input_file, quotechar=None): def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file.""" """Reads a tab separated value file."""
...@@ -125,3 +102,215 @@ class DataProcessor(object): ...@@ -125,3 +102,215 @@ class DataProcessor(object):
line = list(unicode(cell, 'utf-8') for cell in line) line = list(unicode(cell, 'utf-8') for cell in line)
lines.append(line) lines.append(line)
return lines return lines
class SingleSentenceClassificationProcessor(DataProcessor):
    """Generic processor for a single sentence classification data set.

    Stores a list of ``InputExample`` objects together with the labels seen so
    far, and converts them to padded ``InputFeatures`` (optionally wrapped in a
    TensorFlow or PyTorch dataset) through ``get_features``.

    Args:
        labels: initial list of known labels (a new empty list when ``None``).
        examples: initial list of examples (a new empty list when ``None``).
        mode: ``'classification'`` or ``'regression'``; selects how labels are
            encoded by ``get_features``.
        verbose: when ``True``, ``get_features`` logs the first converted examples.
    """

    def __init__(self, labels=None, examples=None, mode='classification', verbose=False):
        self.labels = [] if labels is None else labels
        self.examples = [] if examples is None else examples
        self.mode = mode
        self.verbose = verbose

    def __len__(self):
        """Number of examples currently stored."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return one example, or a new processor for a slice of the examples."""
        if isinstance(idx, slice):
            # Carry `mode` and `verbose` over: without them a slice of a
            # regression processor silently became a classification one.
            return SingleSentenceClassificationProcessor(labels=self.labels,
                                                         examples=self.examples[idx],
                                                         mode=self.mode,
                                                         verbose=self.verbose)
        return self.examples[idx]

    @classmethod
    def create_from_csv(cls, file_name, split_name='', column_label=0, column_text=1,
                        column_id=None, skip_first_row=False, **kwargs):
        """Create a processor and fill it from a file.

        The file is read by ``add_examples_from_csv``; labels and examples of
        the new processor are overwritten by the file content. ``kwargs`` are
        forwarded to the constructor.
        """
        processor = cls(**kwargs)
        processor.add_examples_from_csv(file_name,
                                        split_name=split_name,
                                        column_label=column_label,
                                        column_text=column_text,
                                        column_id=column_id,
                                        skip_first_row=skip_first_row,
                                        overwrite_labels=True,
                                        overwrite_examples=True)
        return processor

    @classmethod
    def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
        """Create a processor from in-memory texts (or ``(text, label)`` pairs).

        ``kwargs`` are forwarded to the constructor.
        """
        processor = cls(**kwargs)
        processor.add_examples(texts_or_text_and_labels, labels=labels)
        return processor

    def add_examples_from_csv(self, file_name, split_name='', column_label=0, column_text=1, column_id=None,
                              skip_first_row=False, overwrite_labels=False, overwrite_examples=False):
        """Read examples from a file and add them to the processor.

        NOTE: despite the name, the file is parsed with ``self._read_tsv``.

        Args:
            file_name: path of the file to read.
            split_name: optional prefix used when generating example ids.
            column_label: index of the column holding the label.
            column_text: index of the column holding the text.
            column_id: index of the column holding the example id; when ``None``
                ids are generated from ``split_name`` and the row index.
            skip_first_row: drop the first row (e.g. a header) before parsing.
            overwrite_labels: replace the known labels instead of extending them.
            overwrite_examples: replace the stored examples instead of extending them.

        Returns:
            The list of examples stored in the processor after the update.
        """
        lines = self._read_tsv(file_name)
        if skip_first_row:
            lines = lines[1:]
        texts = []
        labels = []
        ids = []
        for (i, line) in enumerate(lines):
            texts.append(line[column_text])
            labels.append(line[column_label])
            if column_id is not None:
                ids.append(line[column_id])
            else:
                guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
                ids.append(guid)

        return self.add_examples(texts, labels, ids,
                                 overwrite_labels=overwrite_labels,
                                 overwrite_examples=overwrite_examples)

    def add_examples(self, texts_or_text_and_labels, labels=None, ids=None,
                     overwrite_labels=False, overwrite_examples=False):
        """Add examples (and the labels they introduce) to the processor.

        Args:
            texts_or_text_and_labels: sequence of texts, or of ``(text, label)``
                pairs; a pair is unpacked only when no explicit label is given
                for that position.
            labels: optional sequence of labels, same length as the texts.
            ids: optional sequence of example ids, same length as the texts.
            overwrite_labels: replace the known labels instead of extending them.
            overwrite_examples: replace the stored examples instead of extending them.

        Returns:
            The list of examples stored in the processor after the update.
        """
        assert labels is None or len(texts_or_text_and_labels) == len(labels)
        assert ids is None or len(texts_or_text_and_labels) == len(ids)
        if ids is None:
            ids = [None] * len(texts_or_text_and_labels)
        if labels is None:
            labels = [None] * len(texts_or_text_and_labels)

        examples = []
        added_labels = []  # kept in order of appearance, duplicates removed below
        for (text_or_text_and_label, label, guid) in zip(texts_or_text_and_labels, labels, ids):
            if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
                text, label = text_or_text_and_label
            else:
                text = text_or_text_and_label
            added_labels.append(label)
            examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))

        # Update examples
        if overwrite_examples:
            self.examples = examples
        else:
            self.examples.extend(examples)

        # Update labels. The result is de-duplicated in first-seen order rather
        # than through a bare `set`: `get_features` derives label ids from the
        # position in `self.labels`, so the order must be reproducible between runs.
        previous_labels = [] if overwrite_labels else list(self.labels)
        seen = set()
        ordered_labels = []
        for label in previous_labels + added_labels:
            if label not in seen:
                seen.add(label)
                ordered_labels.append(label)
        self.labels = ordered_labels

        return self.examples

    def get_features(self,
                     tokenizer,
                     max_length=None,
                     pad_on_left=False,
                     pad_token=0,
                     mask_padding_with_zero=True,
                     return_tensors=None):
        """
        Convert the stored examples into a list of ``InputFeatures`` (or a dataset).

        Every example is encoded with ``tokenizer.encode`` and then padded to the
        length of the longest encoded example held by the processor.

        Args:
            tokenizer: Instance of a tokenizer that will tokenize the examples
            max_length: Maximum example length; defaults to ``tokenizer.max_len`` and is
                never allowed to exceed it
            pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
            pad_token: Padding token
            mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
                and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
                actual values)
            return_tensors: ``None`` (default) to get a list of ``InputFeatures``, ``'tf'`` to get a
                ``tf.data.Dataset`` or ``'pt'`` to get a ``torch.utils.data.TensorDataset``

        Returns:
            A list of ``InputFeatures`` when ``return_tensors`` is ``None``; a ``tf.data.Dataset`` yielding
            ``({'input_ids', 'attention_mask'}, label)`` when it is ``'tf'``; a ``TensorDataset`` of
            ``(input_ids, attention_mask, labels)`` when it is ``'pt'``.

        Raises:
            ValueError: if ``self.mode`` is neither ``'classification'`` nor ``'regression'``, or if
                ``return_tensors`` is not one of ``None``, ``'tf'``, ``'pt'``.
            ImportError: if the requested framework is not available.
        """
        if max_length is None:
            max_length = tokenizer.max_len
        # Never ask the tokenizer for more than it declares it can handle.
        max_length = min(max_length, tokenizer.max_len)

        # Label ids are the positions of the labels in `self.labels`.
        label_map = {label: i for i, label in enumerate(self.labels)}

        all_input_ids = []
        for (ex_index, example) in enumerate(self.examples):
            if ex_index % 10000 == 0:
                logger.info("Tokenizing example %d", ex_index)

            input_ids = tokenizer.encode(
                example.text_a,
                add_special_tokens=True,
                max_length=max_length,
            )
            all_input_ids.append(input_ids)

        # Pad to the longest encoded example; an empty processor yields no
        # features instead of failing on `max()` of an empty sequence.
        batch_length = max(len(input_ids) for input_ids in all_input_ids) if all_input_ids else 0

        features = []
        for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, self.examples)):
            if ex_index % 10000 == 0:
                logger.info("Writing example %d", ex_index)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding_length = batch_length - len(input_ids)
            if pad_on_left:
                input_ids = ([pad_token] * padding_length) + input_ids
                attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
            else:
                input_ids = input_ids + ([pad_token] * padding_length)
                attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)

            assert len(input_ids) == batch_length, \
                "Error with input length {} vs {}".format(len(input_ids), batch_length)
            assert len(attention_mask) == batch_length, \
                "Error with input length {} vs {}".format(len(attention_mask), batch_length)

            if self.mode == "classification":
                label = label_map[example.label]
            elif self.mode == "regression":
                label = float(example.label)
            else:
                raise ValueError(self.mode)

            if ex_index < 5 and self.verbose:
                logger.info("*** Example ***")
                logger.info("guid: %s" % (example.guid))
                logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
                logger.info("label: %s (id = %d)" % (example.label, label))

            features.append(
                InputFeatures(input_ids=input_ids,
                              attention_mask=attention_mask,
                              label=label))

        if return_tensors is None:
            return features
        elif return_tensors == 'tf':
            if not is_tf_available():
                raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
            import tensorflow as tf

            def gen():
                for ex in features:
                    yield ({'input_ids': ex.input_ids,
                            'attention_mask': ex.attention_mask},
                           ex.label)

            # Regression labels are floats; declaring them as int64 would not
            # represent them faithfully.
            label_type = tf.float32 if self.mode == "regression" else tf.int64
            dataset = tf.data.Dataset.from_generator(gen,
                                                     ({'input_ids': tf.int32,
                                                       'attention_mask': tf.int32},
                                                      label_type),
                                                     ({'input_ids': tf.TensorShape([None]),
                                                       'attention_mask': tf.TensorShape([None])},
                                                      tf.TensorShape([])))
            return dataset
        elif return_tensors == 'pt':
            if not is_torch_available():
                raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
            import torch
            from torch.utils.data import TensorDataset
            all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
            all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
            if self.mode == "classification":
                all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
            elif self.mode == "regression":
                all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
            else:
                # Only reachable with no examples (the loop above rejects bad
                # modes); fail clearly instead of with an undefined name.
                raise ValueError(self.mode)

            dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
            return dataset
        else:
            raise ValueError("return_tensors should be one of 'tf' or 'pt'")
...@@ -10,10 +10,9 @@ import json ...@@ -10,10 +10,9 @@ import json
import logging import logging
import os import os
import six import six
import shutil
import tempfile import tempfile
import fnmatch import fnmatch
from functools import wraps from functools import partial, wraps
from hashlib import sha256 from hashlib import sha256
from io import open from io import open
...@@ -23,24 +22,36 @@ from botocore.exceptions import ClientError ...@@ -23,24 +22,36 @@ from botocore.exceptions import ClientError
import requests import requests
from tqdm.auto import tqdm from tqdm.auto import tqdm
from contextlib import contextmanager from contextlib import contextmanager
from . import __version__
logger = logging.getLogger(__name__) # pylint: disable=invalid-name from filelock import FileLock
try: logger = logging.getLogger(__name__) # pylint: disable=invalid-name
import tensorflow as tf
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__))
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
try: try:
os.environ.setdefault('USE_TORCH', 'YES')
if os.environ['USE_TORCH'].upper() in ('1', 'ON', 'YES'):
import torch import torch
_torch_available = True # pylint: disable=invalid-name _torch_available = True # pylint: disable=invalid-name
logger.info("PyTorch version {} available.".format(torch.__version__)) logger.info("PyTorch version {} available.".format(torch.__version__))
else:
logger.info("USE_TORCH override through env variable, disabling PyTorch")
_torch_available = False
except ImportError: except ImportError:
_torch_available = False # pylint: disable=invalid-name _torch_available = False # pylint: disable=invalid-name
try:
os.environ.setdefault('USE_TF', 'YES')
if os.environ['USE_TF'].upper() in ('1', 'ON', 'YES'):
import tensorflow as tf
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__))
else:
logger.info("USE_TF override through env variable, disabling Tensorflow")
_tf_available = False
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
try: try:
from torch.hub import _get_torch_home from torch.hub import _get_torch_home
...@@ -72,11 +83,20 @@ WEIGHTS_NAME = "pytorch_model.bin" ...@@ -72,11 +83,20 @@ WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = 'tf_model.h5' TF2_WEIGHTS_NAME = 'tf_model.h5'
TF_WEIGHTS_NAME = 'model.ckpt' TF_WEIGHTS_NAME = 'model.ckpt'
CONFIG_NAME = "config.json" CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"
def is_torch_available(): def is_torch_available():
return _torch_available return _torch_available
def is_tf_available(): def is_tf_available():
return _tf_available return _tf_available
if not six.PY2: if not six.PY2:
...@@ -103,12 +123,25 @@ else: ...@@ -103,12 +123,25 @@ else:
return fn return fn
return docstring_decorator return docstring_decorator
def is_remote_url(url_or_filename):
    """Tell whether `url_or_filename` uses one of the remote schemes.

    Returns ``True`` when the scheme parsed from the argument is ``http``,
    ``https`` or ``s3``, ``False`` otherwise (e.g. for plain file paths).
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme in ('http', 'https', 's3')
def hf_bucket_url(identifier, postfix=None, cdn=False):
    """Build the download URL of a hosted file.

    Args:
        identifier: path component placed right after the endpoint.
        postfix: optional trailing path component; omitted when ``None``.
        cdn: when true, use ``CLOUDFRONT_DISTRIB_PREFIX`` as the endpoint,
            otherwise ``S3_BUCKET_PREFIX``.

    Returns:
        The endpoint, identifier and (if given) postfix joined with ``/``.
    """
    if cdn:
        endpoint = CLOUDFRONT_DISTRIB_PREFIX
    else:
        endpoint = S3_BUCKET_PREFIX
    parts = [endpoint, identifier]
    if postfix is not None:
        parts.append(postfix)
    return "/".join(parts)
def url_to_filename(url, etag=None): def url_to_filename(url, etag=None):
""" """
Convert `url` into a hashed filename in a repeatable way. Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited If `etag` is specified, append its hash to the url's, delimited
by a period. by a period.
If the url ends with .h5 (Keras HDF5 weights) ands '.h5' to the name If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name
so that TF 2.0 can identify it as a HDF5 file so that TF 2.0 can identify it as a HDF5 file
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
""" """
...@@ -153,7 +186,7 @@ def filename_to_url(filename, cache_dir=None): ...@@ -153,7 +186,7 @@ def filename_to_url(filename, cache_dir=None):
return url, etag return url, etag
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False): def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None):
""" """
Given something that might be a URL (or might be a local path), Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and determine which. If it's a URL, download the file and cache it, and
...@@ -163,6 +196,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N ...@@ -163,6 +196,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir). cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
force_download: if True, re-dowload the file even if it's already cached in the cache dir. force_download: if True, re-dowload the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if incompletly recieved file is found. resume_download: if True, resume the download if incompletly recieved file is found.
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE cache_dir = TRANSFORMERS_CACHE
...@@ -171,17 +205,15 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N ...@@ -171,17 +205,15 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
if sys.version_info[0] == 3 and isinstance(cache_dir, Path): if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir) cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename) if is_remote_url(url_or_filename):
if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary) # URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir=cache_dir, return get_from_cache(url_or_filename, cache_dir=cache_dir,
force_download=force_download, proxies=proxies, force_download=force_download, proxies=proxies,
resume_download=resume_download) resume_download=resume_download, user_agent=user_agent)
elif os.path.exists(url_or_filename): elif os.path.exists(url_or_filename):
# File, and it exists. # File, and it exists.
return url_or_filename return url_or_filename
elif parsed.scheme == '': elif urlparse(url_or_filename).scheme == '':
# File, but it doesn't exist. # File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename)) raise EnvironmentError("file {} not found".format(url_or_filename))
else: else:
...@@ -238,14 +270,26 @@ def s3_get(url, temp_file, proxies=None): ...@@ -238,14 +270,26 @@ def s3_get(url, temp_file, proxies=None):
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file, proxies=None, resume_size=0): def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
headers={'Range':'bytes=%d-'%(resume_size,)} if resume_size > 0 else None ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
if isinstance(user_agent, dict):
ua += "; " + "; ".join(
"{}/{}".format(k, v) for k, v in user_agent.items()
)
elif isinstance(user_agent, six.string_types):
ua += "; "+ user_agent
headers = {
"user-agent": ua
}
if resume_size > 0:
headers['Range'] = 'bytes=%d-' % (resume_size,)
response = requests.get(url, stream=True, proxies=proxies, headers=headers) response = requests.get(url, stream=True, proxies=proxies, headers=headers)
if response.status_code == 416: # Range not satisfiable if response.status_code == 416: # Range not satisfiable
return return
content_length = response.headers.get('Content-Length') content_length = response.headers.get('Content-Length')
total = resume_size + int(content_length) if content_length is not None else None total = resume_size + int(content_length) if content_length is not None else None
progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size, desc="Downloading") progress = tqdm(unit="B", unit_scale=True, total=total, initial=resume_size,
desc="Downloading", disable=bool(logger.level<=logging.INFO))
for chunk in response.iter_content(chunk_size=1024): for chunk in response.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks
progress.update(len(chunk)) progress.update(len(chunk))
...@@ -253,7 +297,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0): ...@@ -253,7 +297,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0):
progress.close() progress.close()
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False): def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None):
""" """
Given a URL, look for the corresponding dataset in the local cache. Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file. If it's not there, download it. Then return the path to the cached file.
...@@ -291,28 +335,34 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag ...@@ -291,28 +335,34 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
# If we don't have a connection (etag is None) and can't identify the file # If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one # try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None: if not os.path.exists(cache_path) and etag is None:
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') matching_files = [
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) file
for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
if not file.endswith('.json') and not file.endswith('.lock')
]
if matching_files: if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1]) cache_path = os.path.join(cache_dir, matching_files[-1])
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + '.lock'
with FileLock(lock_path):
if resume_download: if resume_download:
incomplete_path = cache_path + '.incomplete' incomplete_path = cache_path + '.incomplete'
@contextmanager @contextmanager
def _resumable_file_manager(): def _resumable_file_manager():
with open(incomplete_path,'a+b') as f: with open(incomplete_path,'a+b') as f:
yield f yield f
os.remove(incomplete_path)
temp_file_manager = _resumable_file_manager temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path): if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size resume_size = os.stat(incomplete_path).st_size
else: else:
resume_size = 0 resume_size = 0
else: else:
temp_file_manager = tempfile.NamedTemporaryFile temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
resume_size = 0 resume_size = 0
if not os.path.exists(cache_path) or force_download: if etag is not None and (not os.path.exists(cache_path) or force_download):
# Download to temporary file, then copy to cache dir once finished. # Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted. # Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file: with temp_file_manager() as temp_file:
...@@ -324,16 +374,13 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag ...@@ -324,16 +374,13 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls') logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
s3_get(url, temp_file, proxies=proxies) s3_get(url, temp_file, proxies=proxies)
else: else:
http_get(url, temp_file, proxies=proxies, resume_size=resume_size) http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
# we are copying the file before closing it, so flush to avoid truncation # we are copying the file before closing it, so flush to avoid truncation
temp_file.flush() temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path) logger.info("storing %s in cache at %s", url, cache_path)
with open(cache_path, 'wb') as cache_file: os.rename(temp_file.name, cache_path)
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path) logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag} meta = {'url': url, 'etag': etag}
...@@ -344,6 +391,4 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag ...@@ -344,6 +391,4 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
output_string = unicode(output_string, 'utf-8') # The beauty of python 2 output_string = unicode(output_string, 'utf-8') # The beauty of python 2
meta_file.write(output_string) meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)
return cache_path return cache_path
...@@ -131,8 +131,9 @@ class HfApi: ...@@ -131,8 +131,9 @@ class HfApi:
# the client still has to specify it when uploading the file. # the client still has to specify it when uploading the file.
with open(filepath, "rb") as f: with open(filepath, "rb") as f:
pf = TqdmProgressFileReader(f) pf = TqdmProgressFileReader(f)
data = f if pf.total_size > 0 else ""
r = requests.put(urls.write, data=f, headers={ r = requests.put(urls.write, data=data, headers={
"content-type": urls.type, "content-type": urls.type,
}) })
r.raise_for_status() r.raise_for_status()
......
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration base class and utilities."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import copy
import json
import logging
import os
from io import open
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .file_utils import CONFIG_NAME, MODEL_CARD_NAME, WEIGHTS_NAME, TF2_WEIGHTS_NAME, \
cached_path, is_remote_url, hf_bucket_url
logger = logging.getLogger(__name__)
class ModelCard(object):
    r""" Model Card class.
        Store model card as well as methods for loading/downloading/saving model cards.

        Please read the following paper for details and explanation on the sections:
            "Model Cards for Model Reporting"
                by Margaret Mitchell, Simone Wu,
                Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer,
                Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards.
            Link: https://arxiv.org/abs/1810.03993

        Note:
            A model card can be loaded and saved to disk.

        Parameters:
            **kwargs: the recommended sections listed below (each defaults to an empty dict when
                not provided), plus any additional key/value pairs, which are stored as attributes:
                ``model_details``, ``intended_use``, ``factors``, ``metrics``, ``evaluation_data``,
                ``training_data``, ``quantitative_analyses``, ``ethical_considerations``,
                ``caveats_and_recommendations``.
    """

    def __init__(self, **kwargs):
        # Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers)
        self.model_details = kwargs.pop('model_details', {})
        self.intended_use = kwargs.pop('intended_use', {})
        self.factors = kwargs.pop('factors', {})
        self.metrics = kwargs.pop('metrics', {})
        self.evaluation_data = kwargs.pop('evaluation_data', {})
        self.training_data = kwargs.pop('training_data', {})
        self.quantitative_analyses = kwargs.pop('quantitative_analyses', {})
        self.ethical_considerations = kwargs.pop('ethical_considerations', {})
        self.caveats_and_recommendations = kwargs.pop('caveats_and_recommendations', {})

        # Open additional attributes: every remaining keyword becomes an attribute
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError:
                logger.error("Can't set {} with value {} for {}".format(key, value, self))
                # Bare raise re-raises the same exception with its original traceback
                raise

    def save_pretrained(self, save_directory_or_file):
        """ Save a model card object to the directory or file `save_directory_or_file`.

            If the argument is an existing directory, the card is written inside it under the
            predefined model card file name; otherwise the argument is used as the output file path.
        """
        if os.path.isdir(save_directory_or_file):
            # If we save using the predefined names, we can load using `from_pretrained`
            output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
        else:
            output_model_card_file = save_directory_or_file

        self.to_json_file(output_model_card_file)
        logger.info("Model card saved in {}".format(output_model_card_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r""" Instantiate a :class:`~transformers.ModelCard` from a pre-trained model's model card.

        If the model card file cannot be found, downloaded or parsed, a warning is logged and an
        empty model card is used instead of raising.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model card to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model card that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing a model card file saved using the :func:`~transformers.ModelCard.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - a path or url to a saved model card JSON `file`, e.g.: ``./my_model_directory/modelcard.json``.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                card should be cached if the standard cache should not be used.

            kwargs: (`optional`) dict: key/value pairs with which to update the ModelCard object after loading.

                - The values in kwargs of any keys which are model card attributes will be used to override the loaded values.
                - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the `return_unused_kwargs` keyword parameter.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            find_from_standard_name: (`optional`) boolean, default True:
                If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them with our standard modelcard filename.
                Can be used to directly feed a model/config url and access the colocated modelcard.

            return_unused_kwargs: (`optional`) bool:

                - If False, then this function returns just the final model card object.
                - If True, then this function returns a tuple `(model card, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of kwargs which has not been used to update `ModelCard` and is otherwise ignored.

        Examples::

            modelcard = ModelCard.from_pretrained('bert-base-uncased')    # Download model card from S3 and cache.
            modelcard = ModelCard.from_pretrained('./test/saved_model/')  # E.g. model card was saved using `save_pretrained('./test/saved_model/')`
            modelcard = ModelCard.from_pretrained('./test/saved_model/modelcard.json')
            modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)

        """
        cache_dir = kwargs.pop('cache_dir', None)
        proxies = kwargs.pop('proxies', None)
        find_from_standard_name = kwargs.pop('find_from_standard_name', True)
        return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)

        if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
            # For simplicity we use the same pretrained url than the configuration files
            # but with a different suffix (modelcard.json). This suffix is replaced below.
            model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            model_card_file = pretrained_model_name_or_path
        else:
            model_card_file = hf_bucket_url(pretrained_model_name_or_path, postfix=MODEL_CARD_NAME)

        if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
            # Swap any standard config/weights file name for the model card file name
            model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
            model_card_file = model_card_file.replace(WEIGHTS_NAME, MODEL_CARD_NAME)
            model_card_file = model_card_file.replace(TF2_WEIGHTS_NAME, MODEL_CARD_NAME)

        # `json.JSONDecodeError` only exists on Python 3 (where it subclasses ValueError);
        # fall back to ValueError so the handler below also works on Python 2.
        json_decode_error = getattr(json, 'JSONDecodeError', ValueError)
        # Bound up front so the JSON-error message below can never hit an unbound local.
        resolved_model_card_file = None

        try:
            # Load from URL or cache if already cached
            resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, force_download=True,
                                                   proxies=proxies, resume_download=False)
            if resolved_model_card_file == model_card_file:
                logger.info("loading model card file {}".format(model_card_file))
            else:
                logger.info("loading model card file {} from cache at {}".format(
                    model_card_file, resolved_model_card_file))
            # Load model card
            modelcard = cls.from_json_file(resolved_model_card_file)

        except EnvironmentError:
            if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
                logger.warning("Couldn't reach server at '{}' to download model card file.".format(
                    model_card_file))
            else:
                logger.warning("Model name '{}' was not found in model name list ({}). "
                               "We assumed '{}' was a path or url to a model card file named {} or "
                               "a directory containing such a file but couldn't find any such file at this path or url.".format(
                                   pretrained_model_name_or_path,
                                   ', '.join(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
                                   model_card_file, MODEL_CARD_NAME))
            logger.warning("Creating an empty model card.")

            # We fall back on creating an empty model card
            modelcard = cls()

        except json_decode_error:
            logger.warning("Couldn't reach server at '{}' to download model card file or "
                           "model card file is not a valid JSON file. "
                           "Please check network or file content here: {}.".format(model_card_file, resolved_model_card_file))
            logger.warning("Creating an empty model card.")

            # We fall back on creating an empty model card
            modelcard = cls()

        # Update model card with kwargs if needed: keys matching an existing attribute
        # override it and are consumed; all other keys are left in `kwargs`.
        for key in [name for name in kwargs if hasattr(modelcard, name)]:
            setattr(modelcard, key, kwargs.pop(key))

        logger.info("Model card: %s", str(modelcard))
        if return_unused_kwargs:
            return modelcard, kwargs
        else:
            return modelcard

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `ModelCard` from a Python dictionary of parameters."""
        return cls(**json_object)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelCard` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __eq__(self, other):
        """Two model cards are equal when all of their attributes are equal."""
        if not isinstance(other, ModelCard):
            # Let Python fall back to its default comparison instead of raising
            # AttributeError on objects without a `__dict__` (e.g. None or an int).
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (sorted keys, 2-space indent, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
            writer.write(self.to_json_string())
This diff is collapsed.
...@@ -48,6 +48,12 @@ BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { ...@@ -48,6 +48,12 @@ BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin", 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin", 'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin", 'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment