Commit fa84ae26 authored by Aymeric Augustin

Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length: 119 characters.

This is also Thomas's preference, because it allows for explicit variable
names, which make the code easier to understand.
parent 63e3827c
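
For reference, the same reformatting can be reproduced through black's Python API. The following is a minimal sketch, not part of this commit; it assumes black is installed and that black.format_str and black.FileMode (with its line_length field) behave as in the release invoked above:

    # Hypothetical illustration: reformat one snippet from this diff at
    # line length 119, mirroring the command-line invocation above.
    import black

    src = (
        "self.bias = self.add_weight(\n"
        "    'bias',\n"
        "    shape=[1, self.nf],\n"
        "    initializer=tf.zeros_initializer())\n"
    )
    formatted = black.format_str(src, mode=black.FileMode(line_length=119))
    # black joins the call onto a single line (it fits within 119 columns)
    # and normalizes quotes to double quotes, as in the hunks below.
    print(formatted)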
@@ -24,9 +24,9 @@ import tensorflow as tf
from .modeling_tf_utils import shape_list
class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
def __init__(self, vocab_size, d_embed, d_proj, cutoffs, div_val=1,
keep_order=False, **kwargs):
def __init__(self, vocab_size, d_embed, d_proj, cutoffs, div_val=1, keep_order=False, **kwargs):
super(TFAdaptiveSoftmaxMask, self).__init__(**kwargs)
self.vocab_size = vocab_size
@@ -47,52 +47,59 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
def build(self, input_shape):
if self.n_clusters > 0:
self.cluster_weight = self.add_weight(shape=(self.n_clusters, self.d_embed),
initializer='zeros',
trainable=True,
name='cluster_weight')
self.cluster_bias = self.add_weight(shape=(self.n_clusters,),
initializer='zeros',
trainable=True,
name='cluster_bias')
self.cluster_weight = self.add_weight(
shape=(self.n_clusters, self.d_embed), initializer="zeros", trainable=True, name="cluster_weight"
)
self.cluster_bias = self.add_weight(
shape=(self.n_clusters,), initializer="zeros", trainable=True, name="cluster_bias"
)
if self.div_val == 1:
for i in range(len(self.cutoffs)):
if self.d_proj != self.d_embed:
weight = self.add_weight(shape=(self.d_embed, self.d_proj),
initializer='zeros',
trainable=True,
name='out_projs_._{}'.format(i))
weight = self.add_weight(
shape=(self.d_embed, self.d_proj),
initializer="zeros",
trainable=True,
name="out_projs_._{}".format(i),
)
self.out_projs.append(weight)
else:
self.out_projs.append(None)
weight = self.add_weight(shape=(self.vocab_size, self.d_embed,),
initializer='zeros',
trainable=True,
name='out_layers_._{}_._weight'.format(i))
bias = self.add_weight(shape=(self.vocab_size,),
initializer='zeros',
trainable=True,
name='out_layers_._{}_._bias'.format(i))
weight = self.add_weight(
shape=(self.vocab_size, self.d_embed,),
initializer="zeros",
trainable=True,
name="out_layers_._{}_._weight".format(i),
)
bias = self.add_weight(
shape=(self.vocab_size,),
initializer="zeros",
trainable=True,
name="out_layers_._{}_._bias".format(i),
)
self.out_layers.append((weight, bias))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
d_emb_i = self.d_embed // (self.div_val ** i)
weight = self.add_weight(shape=(d_emb_i, self.d_proj),
initializer='zeros',
trainable=True,
name='out_projs_._{}'.format(i))
weight = self.add_weight(
shape=(d_emb_i, self.d_proj), initializer="zeros", trainable=True, name="out_projs_._{}".format(i)
)
self.out_projs.append(weight)
weight = self.add_weight(shape=(r_idx-l_idx, d_emb_i,),
initializer='zeros',
trainable=True,
name='out_layers_._{}_._weight'.format(i))
bias = self.add_weight(shape=(r_idx-l_idx,),
initializer='zeros',
trainable=True,
name='out_layers_._{}_._bias'.format(i))
weight = self.add_weight(
shape=(r_idx - l_idx, d_emb_i,),
initializer="zeros",
trainable=True,
name="out_layers_._{}_._weight".format(i),
)
bias = self.add_weight(
shape=(r_idx - l_idx,),
initializer="zeros",
trainable=True,
name="out_layers_._{}_._bias".format(i),
)
self.out_layers.append((weight, bias))
super(TFAdaptiveSoftmaxMask, self).build(input_shape)
@@ -100,8 +107,8 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
def _logit(x, W, b, proj=None):
y = x
if proj is not None:
y = tf.einsum('ibd,ed->ibe', y, proj)
return tf.einsum('ibd,nd->ibn', y, W) + b
y = tf.einsum("ibd,ed->ibe", y, proj)
return tf.einsum("ibd,nd->ibn", y, W) + b
@staticmethod
def _gather_logprob(logprob, target):
@@ -114,7 +121,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
hidden, target = inputs
head_logprob = 0
if self.n_clusters == 0:
softmax_b = tf.get_variable('bias', [self.config.vocab_size], initializer=tf.zeros_initializer())
softmax_b = tf.get_variable("bias", [self.config.vocab_size], initializer=tf.zeros_initializer())
output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
if target is not None:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output)
@@ -143,7 +150,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
head_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[0])
head_logprob = tf.nn.log_softmax(head_logit)
out.append(head_logprob[..., :self.cutoffs[0]])
out.append(head_logprob[..., : self.cutoffs[0]])
if target is not None:
cur_head_logprob = tf.boolean_mask(head_logprob, mask)
cur_logprob = self._gather_logprob(cur_head_logprob, cur_target)
@@ -170,6 +177,6 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
# Log the loss as a metric (we could log arbitrary metrics,
# including different metrics for training and inference).
self.add_metric(loss, name=self.name, aggregation='mean' if return_mean else '')
self.add_metric(loss, name=self.name, aggregation="mean" if return_mean else "")
return out
@@ -15,8 +15,7 @@
# limitations under the License.
"""TF general model utils."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
@@ -26,12 +25,20 @@ from tensorflow.python.keras.saving import hdf5_format
import h5py
from .configuration_utils import PretrainedConfig
from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, DUMMY_INPUTS,
cached_path, hf_bucket_url, is_remote_url)
from .file_utils import (
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_NAME,
DUMMY_INPUTS,
cached_path,
hf_bucket_url,
is_remote_url,
)
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
logger = logging.getLogger(__name__)
class TFPreTrainedModel(tf.keras.Model):
r""" Base class for all TF models.
@@ -60,7 +67,7 @@ class TFPreTrainedModel(tf.keras.Model):
Returns:
tf.Tensor with dummy inputs
"""
return {'input_ids': tf.constant(DUMMY_INPUTS)}
return {"input_ids": tf.constant(DUMMY_INPUTS)}
def __init__(self, config, *inputs, **kwargs):
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
@@ -70,7 +77,8 @@ class TFPreTrainedModel(tf.keras.Model):
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
)
)
# Save config in model
self.config = config
@@ -151,7 +159,9 @@ class TFPreTrainedModel(tf.keras.Model):
""" Save a model and its configuration file to a directory, so that it
can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
"""
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
assert os.path.isdir(
save_directory
), "Saving path should be a directory where the model and configuration can be saved"
# Save configuration file
self.config.save_pretrained(save_directory)
@@ -230,20 +240,22 @@ class TFPreTrainedModel(tf.keras.Model):
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_pt=True, config=config)
"""
config = kwargs.pop('config', None)
cache_dir = kwargs.pop('cache_dir', None)
from_pt = kwargs.pop('from_pt', False)
force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None)
output_loading_info = kwargs.pop('output_loading_info', False)
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
from_pt = kwargs.pop("from_pt", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
config_path, *model_args,
cache_dir=cache_dir, return_unused_kwargs=True,
config_path,
*model_args,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
**kwargs
@@ -263,9 +275,11 @@ class TFPreTrainedModel(tf.keras.Model):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
raise EnvironmentError("Error no file named {} found in directory {} or `from_pt` set to False".format(
[WEIGHTS_NAME, TF2_WEIGHTS_NAME],
pretrained_model_name_or_path))
raise EnvironmentError(
"Error no file named {} found in directory {} or `from_pt` set to False".format(
[WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
)
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
@@ -273,31 +287,37 @@ class TFPreTrainedModel(tf.keras.Model):
else:
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME)
if from_pt:
raise EnvironmentError("Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name.")
raise EnvironmentError(
"Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name."
)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
resume_download=resume_download, proxies=proxies)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
)
except EnvironmentError as e:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
logger.error("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
", ".join(cls.pretrained_model_archive_map.keys()),
archive_file,
)
)
raise e
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
else:
resolved_archive_file = None
@@ -316,38 +336,42 @@ class TFPreTrainedModel(tf.keras.Model):
try:
model.load_weights(resolved_archive_file, by_name=True)
except OSError:
raise OSError("Unable to load weights from h5 file. "
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. ")
raise OSError(
"Unable to load weights from h5 file. "
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
)
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
# Check if the models are the same to output loading information
with h5py.File(resolved_archive_file, 'r') as f:
if 'layer_names' not in f.attrs and 'model_weights' in f:
f = f['model_weights']
hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, 'layer_names'))
with h5py.File(resolved_archive_file, "r") as f:
if "layer_names" not in f.attrs and "model_weights" in f:
f = f["model_weights"]
hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
model_layer_names = set(layer.name for layer in model.layers)
missing_keys = list(model_layer_names - hdf5_layer_names)
unexpected_keys = list(hdf5_layer_names - model_layer_names)
error_msgs = []
if len(missing_keys) > 0:
logger.info("Layers of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
logger.info(
"Layers of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
)
if len(unexpected_keys) > 0:
logger.info("Layers from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
logger.info(
"Layers from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
)
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading weights for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
raise RuntimeError(
"Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
if output_loading_info:
loading_info = {"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"error_msgs": error_msgs}
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
return model, loading_info
return model
class TFConv1D(tf.keras.layers.Layer):
def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
""" TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
@@ -360,13 +384,9 @@ class TFConv1D(tf.keras.layers.Layer):
def build(self, input_shape):
self.weight = self.add_weight(
"weight",
shape=[self.nx, self.nf],
initializer=get_initializer(self.initializer_range))
self.bias = self.add_weight(
"bias",
shape=[1, self.nf],
initializer=tf.zeros_initializer())
"weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
)
self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())
def call(self, x):
bz, sl = shape_list(x)[:2]
@@ -382,11 +402,12 @@ class TFConv1D(tf.keras.layers.Layer):
class TFSharedEmbeddings(tf.keras.layers.Layer):
"""Construct shared token embeddings.
"""
def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
super(TFSharedEmbeddings, self).__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range
self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range
def build(self, input_shape):
"""Build shared word embedding layer
@@ -394,9 +415,8 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
self.weight = self.add_weight(
"weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(self.initializer_range))
"weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
)
super(TFSharedEmbeddings, self).build(input_shape)
def call(self, inputs, mode="embedding"):
@@ -455,35 +475,36 @@ class TFSequenceSummary(tf.keras.layers.Layer):
summary_first_dropout: Add a dropout before the projection and activation
summary_last_dropout: Add a dropout after the projection and activation
"""
def __init__(self, config, initializer_range=0.02, **kwargs):
super(TFSequenceSummary, self).__init__(**kwargs)
self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last'
if self.summary_type == 'attn':
self.summary_type = config.summary_type if hasattr(config, "summary_use_proj") else "last"
if self.summary_type == "attn":
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
raise NotImplementedError
self.has_summary = hasattr(config, 'summary_use_proj') and config.summary_use_proj
self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
if self.has_summary:
if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
num_classes = config.num_labels
else:
num_classes = config.hidden_size
self.summary = tf.keras.layers.Dense(num_classes,
kernel_initializer=get_initializer(initializer_range),
name='summary')
self.summary = tf.keras.layers.Dense(
num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
)
self.has_activation = hasattr(config, 'summary_activation') and config.summary_activation == 'tanh'
self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
if self.has_activation:
self.activation = tf.keras.activations.tanh
self.has_first_dropout = hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0
self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
if self.has_first_dropout:
self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)
self.has_last_dropout = hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0
self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
if self.has_last_dropout:
self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
@@ -502,29 +523,33 @@ class TFSequenceSummary(tf.keras.layers.Layer):
cls_index = inputs[1] if len(inputs) > 1 else None
assert len(inputs) <= 2, "Too many inputs."
else:
input_ids = inputs.get('input_ids')
cls_index = inputs.get('cls_index', None)
input_ids = inputs.get("input_ids")
cls_index = inputs.get("cls_index", None)
if self.summary_type == 'last':
if self.summary_type == "last":
output = hidden_states[:, -1]
elif self.summary_type == 'first':
elif self.summary_type == "first":
output = hidden_states[:, 0]
elif self.summary_type == 'mean':
elif self.summary_type == "mean":
output = tf.reduce_mean(hidden_states, axis=1)
elif self.summary_type == 'cls_index':
elif self.summary_type == "cls_index":
hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims]
if cls_index is None:
cls_index = tf.fill(hidden_shape[:-2], hidden_shape[-2] - 1) # A tensor of shape [batch] or [batch, num choices] filled with the sequence length
cls_index = tf.fill(
hidden_shape[:-2], hidden_shape[-2] - 1
) # A tensor of shape [batch] or [batch, num choices] filled with the sequence length
cls_shape = shape_list(cls_index)
if len(cls_shape) <= len(hidden_shape) - 2:
cls_index = cls_index[..., tf.newaxis]
# else:
# cls_index = cls_index[..., tf.newaxis]
# cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
# cls_index = cls_index[..., tf.newaxis]
# cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
output = tf.squeeze(output, axis=len(hidden_shape) - 2) # shape of output: (batch, num choices, hidden_size)
elif self.summary_type == 'attn':
output = tf.squeeze(
output, axis=len(hidden_shape) - 2
) # shape of output: (batch, num choices, hidden_size)
elif self.summary_type == "attn":
raise NotImplementedError
if self.has_first_dropout:
@@ -541,12 +566,14 @@ class TFSequenceSummary(tf.keras.layers.Layer):
return output
def shape_list(x):
"""Deal with dynamic shape in tensorflow cleanly."""
static = x.shape.as_list()
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
def get_initializer(initializer_range=0.02):
"""Creates a `tf.initializers.truncated_normal` with the given range.
Args:
@@ -25,30 +25,34 @@ import numpy as np
import tensorflow as tf
from .configuration_xlm import XLMConfig
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer, DUMMY_INPUTS
from .modeling_tf_utils import (
TFPreTrainedModel,
TFSharedEmbeddings,
TFSequenceSummary,
shape_list,
get_initializer,
DUMMY_INPUTS,
)
from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-tf_model.h5",
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-tf_model.h5",
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-tf_model.h5",
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-tf_model.h5",
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-tf_model.h5",
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-tf_model.h5",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-tf_model.h5",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-tf_model.h5",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-tf_model.h5",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-tf_model.h5",
"xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-tf_model.h5",
"xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-tf_model.h5",
"xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-tf_model.h5",
"xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-tf_model.h5",
"xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-tf_model.h5",
"xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-tf_model.h5",
"xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-tf_model.h5",
"xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-tf_model.h5",
"xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-tf_model.h5",
"xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-tf_model.h5",
}
def create_sinusoidal_embeddings(n_pos, dim, out):
position_enc = np.array([
[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
for pos in range(n_pos)
])
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
out[:, 0::2] = tf.constant(np.sin(position_enc[:, 0::2]))
out[:, 1::2] = tf.constant(np.cos(position_enc[:, 1::2]))
@@ -78,8 +82,9 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
# attention mask is the same as mask, or triangular inferior attention (causal)
if causal:
attn_mask = tf.less_equal(tf.tile(alen[tf.newaxis, tf.newaxis, :], (bs, slen, 1)),
alen[tf.newaxis, :, tf.newaxis])
attn_mask = tf.less_equal(
tf.tile(alen[tf.newaxis, tf.newaxis, :], (bs, slen, 1)), alen[tf.newaxis, :, tf.newaxis]
)
else:
attn_mask = mask
@@ -106,10 +111,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
self.n_heads = n_heads
assert self.dim % self.n_heads == 0
self.q_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name='q_lin')
self.k_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name='k_lin')
self.v_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name='v_lin')
self.out_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name='out_lin')
self.q_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin")
self.k_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="k_lin")
self.v_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="v_lin")
self.out_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="out_lin")
self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
self.pruned_heads = set()
@@ -125,7 +130,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
bs, qlen, dim = shape_list(input)
if kv is None:
klen = qlen if cache is None else cache['slen'] + qlen
klen = qlen if cache is None else cache["slen"] + qlen
else:
klen = shape_list(kv)[1]
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
@@ -141,40 +146,40 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
""" compute context """
return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
if kv is None:
k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
elif cache is None or self.layer_id not in cache:
k = v = kv
k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head)
k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head)
if cache is not None:
if self.layer_id in cache:
if kv is None:
k_, v_ = cache[self.layer_id]
k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head)
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head)
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = cache[self.layer_id]
cache[self.layer_id] = (k, v)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
scores = scores - 1e30 * (1.0 - mask)
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
# Mask heads if we want to
if head_mask is not None:
weights = weights * head_mask
context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim)
context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim)
outputs = (self.out_lin(context),)
if self.output_attentions:
@@ -183,11 +188,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
class TFTransformerFFN(tf.keras.layers.Layer):
def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs):
super(TFTransformerFFN, self).__init__(**kwargs)
self.lin1 = tf.keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name='lin1')
self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name='lin2')
self.lin1 = tf.keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1")
self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2")
self.act = tf.keras.layers.Activation(gelu) if config.gelu_activation else tf.keras.activations.relu
self.dropout = tf.keras.layers.Dropout(config.dropout)
@@ -226,30 +230,36 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# assert len(self.id2lang) == len(self.lang2id) == self.n_langs
# model parameters
self.dim = config.emb_dim # 512 by default
self.dim = config.emb_dim # 512 by default
self.hidden_dim = self.dim * 4 # 2048 by default
self.n_heads = config.n_heads # 8 by default
self.n_heads = config.n_heads # 8 by default
self.n_layers = config.n_layers
assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'
assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"
# embeddings
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.attention_dropout = tf.keras.layers.Dropout(config.attention_dropout)
self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings,
self.dim,
embeddings_initializer=get_initializer(config.embed_init_std),
name='position_embeddings')
self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
self.dim,
embeddings_initializer=get_initializer(config.embed_init_std),
name="position_embeddings",
)
if config.sinusoidal_embeddings:
raise NotImplementedError
# create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
if config.n_langs > 1 and config.use_lang_emb:
self.lang_embeddings = tf.keras.layers.Embedding(self.n_langs,
self.dim,
embeddings_initializer=get_initializer(config.embed_init_std),
name='lang_embeddings')
self.embeddings = TFSharedEmbeddings(self.n_words, self.dim, initializer_range=config.embed_init_std, name='embeddings') # padding_idx=self.pad_index)
self.layer_norm_emb = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm_emb')
self.lang_embeddings = tf.keras.layers.Embedding(
self.n_langs,
self.dim,
embeddings_initializer=get_initializer(config.embed_init_std),
name="lang_embeddings",
)
self.embeddings = TFSharedEmbeddings(
self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
) # padding_idx=self.pad_index)
self.layer_norm_emb = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm_emb")
# transformer layers
self.attentions = []
@@ -261,13 +271,21 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# self.encoder_attn = []
for i in range(self.n_layers):
self.attentions.append(TFMultiHeadAttention(self.n_heads, self.dim, config=config, name='attentions_._{}'.format(i)))
self.layer_norm1.append(tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm1_._{}'.format(i)))
self.attentions.append(
TFMultiHeadAttention(self.n_heads, self.dim, config=config, name="attentions_._{}".format(i))
)
self.layer_norm1.append(
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1_._{}".format(i))
)
# if self.is_decoder:
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
# self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
self.ffns.append(TFTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name='ffns_._{}'.format(i)))
self.layer_norm2.append(tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm2_._{}'.format(i)))
self.ffns.append(
TFTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name="ffns_._{}".format(i))
)
self.layer_norm2.append(
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2_._{}".format(i))
)
if hasattr(config, "pruned_heads"):
pruned_heads = config.pruned_heads.copy().items()
@@ -276,7 +294,6 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
if self.attentions[int(layer)].n_heads == config.n_heads:
self.prune_heads({int(layer): list(map(int, heads))})
def get_input_embeddings(self):
return self.embeddings
@@ -290,9 +307,19 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
def call(self, inputs, attention_mask=None, langs=None, token_type_ids=None,
position_ids=None, lengths=None, cache=None, head_mask=None, inputs_embeds=None,
training=False): # removed: src_enc=None, src_len=None
def call(
self,
inputs,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
training=False,
): # removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
@@ -305,15 +332,15 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', attention_mask)
langs = inputs.get('langs', langs)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
lengths = inputs.get('lengths', lengths)
cache = inputs.get('cache', cache)
head_mask = inputs.get('head_mask', head_mask)
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
lengths = inputs.get("lengths", lengths)
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
@@ -331,7 +358,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
if input_ids is not None:
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
else:
lengths = tf.convert_to_tensor([slen]*bs, tf.int32)
lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
# mask = input_ids != self.pad_index
# check inputs
@@ -375,7 +402,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# do not recompute cached elements
if cache is not None and input_ids is not None:
_slen = slen - cache['slen']
_slen = slen - cache["slen"]
input_ids = input_ids[:, -_slen:]
position_ids = position_ids[:, -_slen:]
if langs is not None:
@@ -430,7 +457,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# update cache length
if cache is not None:
cache['slen'] += tensor.size(1)
cache["slen"] += tensor.size(1)
# move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1)
@@ -447,6 +474,7 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = XLMConfig
pretrained_model_archive_map = TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
@@ -460,7 +488,7 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
else:
langs_list = None
return {'input_ids': inputs_list, 'attention_mask': attns_list, 'langs': langs_list}
return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
XLM_START_DOCSTRING = r""" The XLM model was proposed in
@@ -554,8 +582,12 @@ XLM_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare XLM Model transformer outputing raw hidden-states without any specific head on top.",
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
@add_start_docstrings(
"The bare XLM Model transformer outputing raw hidden-states without any specific head on top.",
XLM_START_DOCSTRING,
XLM_INPUTS_DOCSTRING,
)
class TFXLMModel(TFXLMPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -581,20 +613,21 @@ class TFXLMModel(TFXLMPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLMModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFXLMMainLayer(config, name='transformer')
self.transformer = TFXLMMainLayer(config, name="transformer")
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
return outputs
class TFXLMPredLayer(tf.keras.layers.Layer):
"""
Prediction layer (cross_entropy or adaptive_softmax).
"""
def __init__(self, config, input_embeddings, **kwargs):
super(TFXLMPredLayer, self).__init__(**kwargs)
self.asm = config.asm
@@ -614,10 +647,7 @@ class TFXLMPredLayer(tf.keras.layers.Layer):
def build(self, input_shape):
# The output weights are the same as the input embeddings, but there is an output-only bias for each token.
self.bias = self.add_weight(shape=(self.n_words,),
initializer='zeros',
trainable=True,
name='bias')
self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias")
super(TFXLMPredLayer, self).build(input_shape)
def call(self, hidden_states):
@@ -626,9 +656,12 @@ class TFXLMPredLayer(tf.keras.layers.Layer):
return hidden_states
@add_start_docstrings("""The XLM Model transformer with a language modeling head on top
@add_start_docstrings(
"""The XLM Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """,
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
XLM_START_DOCSTRING,
XLM_INPUTS_DOCSTRING,
)
class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -654,10 +687,11 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLMWithLMHeadModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFXLMMainLayer(config, name='transformer')
self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')
self.transformer = TFXLMMainLayer(config, name="transformer")
self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
def get_output_embeddings(self):
return self.pred_layer.input_embeddings
@@ -672,9 +706,12 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
return outputs
@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
@add_start_docstrings(
"""XLM Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
XLM_START_DOCSTRING,
XLM_INPUTS_DOCSTRING,
)
class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -701,12 +738,13 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
logits = outputs[0]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLMForSequenceClassification, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.transformer = TFXLMMainLayer(config, name='transformer')
self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name='sequence_summary')
self.transformer = TFXLMMainLayer(config, name="transformer")
self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
@@ -718,9 +756,12 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
return outputs
@add_start_docstrings("""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
@add_start_docstrings(
"""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
XLM_START_DOCSTRING,
XLM_INPUTS_DOCSTRING,
)
class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -748,12 +789,13 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
start_scores, end_scores = outputs[:2]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLMForQuestionAnsweringSimple, self).__init__(config, *inputs, **kwargs)
self.transformer = TFXLMMainLayer(config, name='transformer')
self.qa_outputs = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(config.init_std),
name='qa_outputs')
self.transformer = TFXLMMainLayer(config, name="transformer")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.init_std), name="qa_outputs"
)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
......@@ -765,6 +807,8 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
outputs = (start_logits, end_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if they are in it
outputs = (start_logits, end_logits,) + transformer_outputs[
1:
] # Keep mems, hidden states, attentions if they are in it
return outputs # start_logits, end_logits, (hidden_states), (attentions)
@@ -35,8 +35,8 @@ from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-tf_model.h5",
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-tf_model.h5",
"xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-tf_model.h5",
"xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-tf_model.h5",
}
@@ -45,8 +45,7 @@ def gelu(x):
XLNet is using OpenAI GPT's gelu
Also see https://arxiv.org/abs/1606.08415
"""
cdf = 0.5 * (1.0 + tf.tanh(
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
@@ -54,9 +53,11 @@ def swish(x):
return x * tf.sigmoid(x)
ACT2FN = {"gelu": tf.keras.layers.Activation(gelu),
"relu": tf.keras.activations.relu,
"swish": tf.keras.layers.Activation(swish)}
ACT2FN = {
"gelu": tf.keras.layers.Activation(gelu),
"relu": tf.keras.activations.relu,
"swish": tf.keras.layers.Activation(swish),
}
class TFXLNetRelativeAttention(tf.keras.layers.Layer):
@@ -67,7 +68,8 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
if config.d_model % config.n_head != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.d_model, config.n_head))
"heads (%d)" % (config.d_model, config.n_head)
)
self.n_head = config.n_head
self.d_head = config.d_head
@@ -75,38 +77,38 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
self.scale = 1 / (config.d_head ** 0.5)
self.initializer_range = config.initializer_range
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm')
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout)
def build(self, input_shape):
initializer = get_initializer(self.initializer_range)
self.q = self.add_weight(shape=(self.d_model, self.n_head, self.d_head),
initializer=initializer,
trainable=True, name='q')
self.k = self.add_weight(shape=(self.d_model, self.n_head, self.d_head),
initializer=initializer,
trainable=True, name='k')
self.v = self.add_weight(shape=(self.d_model, self.n_head, self.d_head),
initializer=initializer,
trainable=True, name='v')
self.o = self.add_weight(shape=(self.d_model, self.n_head, self.d_head),
initializer=initializer,
trainable=True, name='o')
self.r = self.add_weight(shape=(self.d_model, self.n_head, self.d_head),
initializer=initializer,
trainable=True, name='r')
self.r_r_bias = self.add_weight(shape=(self.n_head, self.d_head),
initializer='zeros',
trainable=True, name='r_r_bias')
self.r_s_bias = self.add_weight(shape=(self.n_head, self.d_head),
initializer='zeros',
trainable=True, name='r_s_bias')
self.r_w_bias = self.add_weight(shape=(self.n_head, self.d_head),
initializer='zeros',
trainable=True, name='r_w_bias')
self.seg_embed = self.add_weight(shape=(2, self.n_head, self.d_head),
initializer=initializer,
trainable=True, name='seg_embed')
self.q = self.add_weight(
shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="q"
)
self.k = self.add_weight(
shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="k"
)
self.v = self.add_weight(
shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="v"
)
self.o = self.add_weight(
shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="o"
)
self.r = self.add_weight(
shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="r"
)
self.r_r_bias = self.add_weight(
shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias"
)
self.r_s_bias = self.add_weight(
shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_s_bias"
)
self.r_w_bias = self.add_weight(
shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias"
)
self.seg_embed = self.add_weight(
shape=(2, self.n_head, self.d_head), initializer=initializer, trainable=True, name="seg_embed"
)
super(TFXLNetRelativeAttention, self).build(input_shape)
def prune_heads(self, heads):
@@ -130,18 +132,18 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask = inputs
# content based attention score
ac = tf.einsum('ibnd,jbnd->ijbn', q_head + self.r_w_bias, k_head_h)
ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)
# position based attention score
bd = tf.einsum('ibnd,jbnd->ijbn', q_head + self.r_r_bias, k_head_r)
bd = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_r_bias, k_head_r)
bd = self.rel_shift(bd, klen=shape_list(ac)[1])
# segment based attention score
if seg_mat is None:
ef = 0
else:
ef = tf.einsum('ibnd,snd->ibns', q_head + self.r_s_bias, self.seg_embed)
ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)
ef = tf.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
ef = tf.einsum("ijbs,ibns->ijbn", seg_mat, ef)
# merge attention scores and perform masking
attn_score = (ac + bd + ef) * self.scale
@@ -162,7 +164,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
attn_prob = attn_prob * head_mask
# attention output
attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)
attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)
if self.output_attentions:
return attn_vec, attn_prob
@@ -174,7 +176,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# post-attention projection (back to `d_model`)
h, attn_vec = inputs
attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, self.o)
attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o)
attn_out = self.dropout(attn_out, training=training)
@@ -185,8 +187,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
return output
def call(self, inputs, training=False):
(h, g, attn_mask_h, attn_mask_g,
r, seg_mat, mems, target_mapping, head_mask) = inputs
(h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask) = inputs
if g is not None:
###### Two-stream attention with relative positional encoding.
@@ -197,22 +198,22 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
cat = h
# content-based key head
k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.k)
k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)
# content-based value head
v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.v)
v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)
# position-based key head
k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.r)
k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
##### h-stream
# content-stream query head
q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.q)
q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
# core attention ops
attn_vec_h = self.rel_attn_core(
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask],
training=training)
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask], training=training
)
if self.output_attentions:
attn_vec_h, attn_prob_h = attn_vec_h
@@ -222,23 +223,23 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
##### g-stream
# query-stream query head
q_head_g = tf.einsum('ibh,hnd->ibnd', g, self.q)
q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)
# core attention ops
if target_mapping is not None:
q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
attn_vec_g = self.rel_attn_core(
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask],
training=training)
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask], training=training
)
if self.output_attentions:
attn_vec_g, attn_prob_g = attn_vec_g
attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
else:
attn_vec_g = self.rel_attn_core(
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask],
training=training)
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask], training=training
)
if self.output_attentions:
attn_vec_g, attn_prob_g = attn_vec_g
@@ -257,17 +258,17 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
cat = h
# content heads
q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.q)
k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.k)
v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.v)
q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)
v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)
# positional heads
k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.r)
k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
# core attention ops
attn_vec = self.rel_attn_core(
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask],
training=training)
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask], training=training
)
if self.output_attentions:
attn_vec, attn_prob = attn_vec
@@ -281,19 +282,21 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
outputs = outputs + (attn_prob,)
return outputs
class TFXLNetFeedForward(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFXLNetFeedForward, self).__init__(**kwargs)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm')
self.layer_1 = tf.keras.layers.Dense(config.d_inner,
kernel_initializer=get_initializer(config.initializer_range),
name='layer_1')
self.layer_2 = tf.keras.layers.Dense(config.d_model,
kernel_initializer=get_initializer(config.initializer_range),
name='layer_2')
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.layer_1 = tf.keras.layers.Dense(
config.d_inner, kernel_initializer=get_initializer(config.initializer_range), name="layer_1"
)
self.layer_2 = tf.keras.layers.Dense(
config.d_model, kernel_initializer=get_initializer(config.initializer_range), name="layer_2"
)
self.dropout = tf.keras.layers.Dropout(config.dropout)
if isinstance(config.ff_activation, str) or \
(sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)):
if isinstance(config.ff_activation, str) or (
sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)
):
self.activation_function = ACT2FN[config.ff_activation]
else:
self.activation_function = config.ff_activation
@@ -308,11 +311,12 @@ class TFXLNetFeedForward(tf.keras.layers.Layer):
output = self.layer_norm(output + inp)
return output
class TFXLNetLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFXLNetLayer, self).__init__(**kwargs)
self.rel_attn = TFXLNetRelativeAttention(config, name='rel_attn')
self.ff = TFXLNetFeedForward(config, name='ff')
self.rel_attn = TFXLNetRelativeAttention(config, name="rel_attn")
self.ff = TFXLNetFeedForward(config, name="ff")
self.dropout = tf.keras.layers.Dropout(config.dropout)
def call(self, inputs, training=False):
@@ -336,10 +340,7 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
self.input_embeddings = input_embeddings
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,),
initializer='zeros',
trainable=True,
name='bias')
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super(TFXLNetLMHead, self).build(input_shape)
def call(self, hidden_states):
@@ -366,8 +367,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
self.use_bfloat16 = config.use_bfloat16
self.initializer_range = config.initializer_range
self.word_embedding = TFSharedEmbeddings(config.vocab_size, config.d_model, initializer_range=config.initializer_range, name='word_embedding')
self.layer = [TFXLNetLayer(config, name='layer_._{}'.format(i)) for i in range(config.n_layer)]
self.word_embedding = TFSharedEmbeddings(
config.vocab_size, config.d_model, initializer_range=config.initializer_range, name="word_embedding"
)
self.layer = [TFXLNetLayer(config, name="layer_._{}".format(i)) for i in range(config.n_layer)]
self.dropout = tf.keras.layers.Dropout(config.dropout)
def get_input_embeddings(self):
@@ -375,9 +378,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def build(self, input_shape):
initializer = get_initializer(self.initializer_range)
self.mask_emb = self.add_weight(shape=(1, 1, self.d_model),
initializer=initializer,
trainable=True, name='mask_emb')
self.mask_emb = self.add_weight(
shape=(1, 1, self.d_model), initializer=initializer, trainable=True, name="mask_emb"
)
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -417,18 +420,18 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def cache_mem(self, curr_out, prev_mem):
"""cache hidden states into memory."""
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len]
curr_out = curr_out[: self.reuse_len]
if prev_mem is None:
new_mem = curr_out[-self.mem_len:]
new_mem = curr_out[-self.mem_len :]
else:
new_mem = tf.concat([prev_mem, curr_out], 0)[-self.mem_len:]
new_mem = tf.concat([prev_mem, curr_out], 0)[-self.mem_len :]
return tf.stop_gradient(new_mem)
@staticmethod
def positional_embedding(pos_seq, inv_freq, bsz=None):
sinusoid_inp = tf.einsum('i,d->id', pos_seq, inv_freq)
sinusoid_inp = tf.einsum("i,d->id", pos_seq, inv_freq)
pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], axis=-1)
pos_emb = pos_emb[:, None, :]
@@ -444,14 +447,14 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
freq_seq = tf.cast(freq_seq, dtype=dtype)
inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
if self.attn_type == 'bi':
if self.attn_type == "bi":
# beg, end = klen - 1, -qlen
beg, end = klen, -qlen
elif self.attn_type == 'uni':
elif self.attn_type == "uni":
# beg, end = klen - 1, -1
beg, end = klen, -1
else:
raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))
raise ValueError("Unknown `attn_type` {}.".format(self.attn_type))
if self.bi_data:
fwd_pos_seq = tf.range(beg, end, -1.0)
@@ -467,9 +470,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if bsz is not None:
# With bi_data, the batch size should be divisible by 2.
assert bsz%2 == 0
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
assert bsz % 2 == 0
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
else:
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)
......@@ -485,8 +488,19 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
return pos_emb
def call(self, inputs, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, training=False):
def call(
self,
inputs,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
......@@ -499,15 +513,15 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', attention_mask)
mems = inputs.get('mems', mems)
perm_mask = inputs.get('perm_mask', perm_mask)
target_mapping = inputs.get('target_mapping', target_mapping)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
input_mask = inputs.get('input_mask', input_mask)
head_mask = inputs.get('head_mask', head_mask)
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
mems = inputs.get("mems", mems)
perm_mask = inputs.get("perm_mask", perm_mask)
target_mapping = inputs.get("target_mapping", target_mapping)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
input_mask = inputs.get("input_mask", input_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
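The parsing above means the same layer accepts three equivalent calling conventions; a hypothetical usage sketch, assuming `layer` is a built TFXLNetMainLayer and the tensors already exist:

# 1) a bare tensor of token ids
outputs = layer(input_ids)
# 2) a positional list, ordered as in the signature
outputs = layer([input_ids, attention_mask, mems])
# 3) a dict keyed by argument name
outputs = layer({"input_ids": input_ids, "attention_mask": attention_mask})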
......@@ -540,17 +554,19 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
##### Attention mask
# causal attention mask
if self.attn_type == 'uni':
if self.attn_type == "uni":
attn_mask = self.create_mask(qlen, mlen)
attn_mask = attn_mask[:, :, None, None]
elif self.attn_type == 'bi':
elif self.attn_type == "bi":
attn_mask = None
else:
raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
raise ValueError("Unsupported attention type: {}".format(self.attn_type))
# data mask: input mask & perm mask
assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " \
assert input_mask is None or attention_mask is None, (
"You can only use one of input_mask (uses 1 for padding) "
"or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one."
)
if input_mask is None and attention_mask is not None:
input_mask = 1.0 - tf.cast(attention_mask, dtype=dtype_float)
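The two mask conventions differ only in polarity, e.g. (values assumed):

attention_mask = tf.constant([1.0, 1.0, 1.0, 0.0])  # BERT convention: 0 marks padding
input_mask = 1.0 - attention_mask  # XLNet convention: 1 marks padding
print(input_mask.numpy())  # [0. 0. 0. 1.]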
if input_mask is not None and perm_mask is not None:
......@@ -564,8 +580,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if data_mask is not None:
# all mems can be attended to
mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz],
dtype=dtype_float)
mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
data_mask = tf.concat([mems_mask, data_mask], axis=1)
if attn_mask is None:
attn_mask = data_mask[:, :, :, None]
......@@ -590,9 +605,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
output_h = self.dropout(word_emb_k, training=training)
if target_mapping is not None:
word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
# else: # We removed the inp_q input, which was the same as target mapping
# inp_q_ext = inp_q[:, :, None]
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
# else: # We removed the inp_q input, which was the same as target mapping
# inp_q_ext = inp_q[:, :, None]
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_g = self.dropout(word_emb_q, training=training)
else:
output_g = None
......@@ -604,9 +619,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
cat_ids = tf.concat([mem_pad, token_type_ids], 0)
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = tf.cast(
tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
tf.int32)
seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
else:
seg_mat = None
......@@ -626,7 +639,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
head_mask = head_mask.to(
dtype=next(self.parameters()).dtype
) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.n_layer
......@@ -643,9 +658,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if self.output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
outputs = layer_module([output_h, output_g, non_tgt_mask, attn_mask,
pos_emb, seg_mat, mems[i], target_mapping,
head_mask[i]], training=training)
outputs = layer_module(
[output_h, output_g, non_tgt_mask, attn_mask, pos_emb, seg_mat, mems[i], target_mapping, head_mask[i]],
training=training,
)
output_h, output_g = outputs[:2]
if self.output_attentions:
attentions.append(outputs[2])
......@@ -679,6 +695,7 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = XLNetConfig
pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
......@@ -784,8 +801,12 @@ XLNET_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare XLNet Model transformer outputing raw hidden-states without any specific head on top.",
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
@add_start_docstrings(
"The bare XLNet Model transformer outputing raw hidden-states without any specific head on top.",
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING,
)
class TFXLNetModel(TFXLNetPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
......@@ -816,18 +837,22 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLNetModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.transformer = TFXLNetMainLayer(config, name="transformer")
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
return outputs
@add_start_docstrings("""XLNet Model with a language modeling head on top
@add_start_docstrings(
"""XLNet Model with a language modeling head on top
(linear layer with weights tied to the input embeddings). """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING,
)
class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
......@@ -865,10 +890,11 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLNetLMHeadModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss')
self.transformer = TFXLNetMainLayer(config, name="transformer")
self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name="lm_loss")
def get_output_embeddings(self):
return self.lm_loss.input_embeddings
......@@ -883,9 +909,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
return outputs # return logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
@add_start_docstrings(
"""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING,
)
class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
......@@ -916,15 +945,18 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
logits = outputs[0]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLNetForSequenceClassification, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.sequence_summary = TFSequenceSummary(config, initializer_range=config.initializer_range, name='sequence_summary')
self.logits_proj = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name='logits_proj')
self.transformer = TFXLNetMainLayer(config, name="transformer")
self.sequence_summary = TFSequenceSummary(
config, initializer_range=config.initializer_range, name="sequence_summary"
)
self.logits_proj = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
......@@ -938,9 +970,12 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
return outputs # return logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a token classification head on top (a linear layer on top of
@add_start_docstrings(
"""XLNet Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING,
)
class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
......@@ -971,14 +1006,15 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
scores = outputs[0]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLNetForTokenClassification, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.classifier = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name='classifier')
self.transformer = TFXLNetMainLayer(config, name="transformer")
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
......@@ -1027,12 +1063,13 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
start_scores, end_scores = outputs[:2]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLNetForQuestionAnsweringSimple, self).__init__(config, *inputs, **kwargs)
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.qa_outputs = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name='qa_outputs')
self.transformer = TFXLNetMainLayer(config, name="transformer")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
......@@ -1044,10 +1081,13 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
outputs = (start_logits, end_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if they are in it
outputs = (start_logits, end_logits,) + transformer_outputs[
1:
] # Keep mems, hidden states, attentions if they are in it
return outputs # start_logits, end_logits, (mems), (hidden_states), (attentions)
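For clarity, a standalone sketch (shapes assumed, not part of the diff) of the split-and-squeeze above, which turns one Dense(2) output into per-token start and end scores:

import tensorflow as tf

logits = tf.random.normal([2, 5, 2])  # [bsz, seq_len, 2] from qa_outputs
start_logits, end_logits = tf.split(logits, 2, axis=-1)  # two [bsz, seq_len, 1] tensors
start_logits = tf.squeeze(start_logits, axis=-1)  # [bsz, seq_len]
end_logits = tf.squeeze(end_logits, axis=-1)  # [bsz, seq_len]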
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
# the hidden-states output to compute `span start logits` and `span end logits`). """,
# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
......
......@@ -42,65 +42,62 @@ from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
}
def build_tf_to_pytorch_map(model, config):
""" A map of modules from TF to PyTorch.
This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
"""
tf_to_pt_map = {}
if hasattr(model, 'transformer'):
if hasattr(model, "transformer"):
# We are loading a TransfoXLLMHeadModel => we will also load the Adaptive Softmax
tf_to_pt_map.update({
"transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
"transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias})
for i, (out_l, proj_l, tie_proj) in enumerate(zip(
model.crit.out_layers,
model.crit.out_projs,
config.tie_projs)):
tf_to_pt_map.update(
{
"transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
"transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias,
}
)
for i, (out_l, proj_l, tie_proj) in enumerate(
zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs)
):
layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
if config.tie_weight:
tf_to_pt_map.update({
layer_str + 'b': out_l.bias})
tf_to_pt_map.update({layer_str + "b": out_l.bias})
else:
raise NotImplementedError
# I don't think this is implemented in the TF code
tf_to_pt_map.update({
layer_str + 'lookup_table': out_l.weight,
layer_str + 'b': out_l.bias})
tf_to_pt_map.update({layer_str + "lookup_table": out_l.weight, layer_str + "b": out_l.bias})
if not tie_proj:
tf_to_pt_map.update({
layer_str + 'proj': proj_l
})
tf_to_pt_map.update({layer_str + "proj": proj_l})
# Now load the rest of the transformer
model = model.transformer
# Embeddings
for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
tf_to_pt_map.update({
layer_str + 'lookup_table': embed_l.weight,
layer_str + 'proj_W': proj_l
})
tf_to_pt_map.update({layer_str + "lookup_table": embed_l.weight, layer_str + "proj_W": proj_l})
# Transformer blocks
for i, b in enumerate(model.layers):
layer_str = "transformer/layer_%d/" % i
tf_to_pt_map.update({
layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
})
tf_to_pt_map.update(
{
layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
}
)
# Relative positioning biases
if config.untie_r:
......@@ -112,11 +109,10 @@ def build_tf_to_pytorch_map(model, config):
else:
r_r_list = [model.r_r_bias]
r_w_list = [model.r_w_bias]
tf_to_pt_map.update({
'transformer/r_r_bias': r_r_list,
'transformer/r_w_bias': r_w_list})
tf_to_pt_map.update({"transformer/r_r_bias": r_r_list, "transformer/r_w_bias": r_w_list})
return tf_to_pt_map
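The returned dict maps TF variable names to PyTorch parameters (or lists of them, for the stacked relative-position biases); a hedged sketch of how a caller might inspect it:

tf_to_pt_map = build_tf_to_pytorch_map(model, config)
for tf_name, pointer in tf_to_pt_map.items():
    # pointer is a torch Parameter, or a list of Parameters for r_r_bias / r_w_bias
    print(tf_name, type(pointer).__name__)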
def load_tf_weights_in_transfo_xl(model, config, tf_path):
""" Load tf checkpoints in a pytorch model
"""
......@@ -124,8 +120,10 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
import numpy as np
import tensorflow as tf
except ImportError:
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
logger.error(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
# Build TF to PyTorch weights loading map
tf_to_pt_map = build_tf_to_pytorch_map(model, config)
......@@ -143,9 +141,9 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
array = tf_weights[name]
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using the pretrained model
if 'kernel' in name or 'proj' in name:
if "kernel" in name or "proj" in name:
array = np.transpose(array)
if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
if ("r_r_bias" in name or "r_w_bias" in name) and len(pointer) > 1:
# Here we will split the TF weights
assert len(pointer) == array.shape[0]
for i, p_i in enumerate(pointer):
......@@ -166,10 +164,10 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
tf_weights.pop(name, None)
tf_weights.pop(name + '/Adam', None)
tf_weights.pop(name + '/Adam_1', None)
tf_weights.pop(name + "/Adam", None)
tf_weights.pop(name + "/Adam_1", None)
logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys())))
return model
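The transpose above exists because TF stores a dense kernel as (in_features, out_features) while torch.nn.Linear.weight is (out_features, in_features); a minimal check (shapes assumed):

import numpy as np

tf_kernel = np.zeros((1024, 4096))  # TF dense kernel: (in_features, out_features)
pt_weight = np.transpose(tf_kernel)  # torch.nn.Linear.weight layout: (out_features, in_features)
assert pt_weight.shape == (4096, 1024)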
......@@ -180,17 +178,16 @@ class PositionalEmbedding(nn.Module):
self.demb = demb
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
self.register_buffer('inv_freq', inv_freq)
self.register_buffer("inv_freq", inv_freq)
def forward(self, pos_seq, bsz=None):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
if bsz is not None:
return pos_emb[:,None,:].expand(-1, bsz, -1)
return pos_emb[:, None, :].expand(-1, bsz, -1)
else:
return pos_emb[:,None,:]
return pos_emb[:, None, :]
class PositionwiseFF(nn.Module):
......@@ -202,7 +199,8 @@ class PositionwiseFF(nn.Module):
self.dropout = dropout
self.CoreNet = nn.Sequential(
nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
nn.Linear(d_model, d_inner),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(d_inner, d_model),
nn.Dropout(dropout),
......@@ -230,10 +228,22 @@ class PositionwiseFF(nn.Module):
class RelPartialLearnableMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
r_r_bias=None, r_w_bias=None, output_attentions=False,
layer_norm_epsilon=1e-5):
def __init__(
self,
n_head,
d_model,
d_head,
dropout,
dropatt=0,
tgt_len=None,
ext_len=None,
mem_len=None,
pre_lnorm=False,
r_r_bias=None,
r_w_bias=None,
output_attentions=False,
layer_norm_epsilon=1e-5,
):
super(RelPartialLearnableMultiHeadAttn, self).__init__()
self.output_attentions = output_attentions
......@@ -254,7 +264,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
self.pre_lnorm = pre_lnorm
if r_r_bias is None or r_w_bias is None: # Biases are not shared
if r_r_bias is None or r_w_bias is None: # Biases are not shared
self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
else:
......@@ -299,18 +309,18 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
klen = w_head_k.size(0)
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # rlen x n_head x d_head
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)  # rlen x n_head x d_head
#### compute attention score
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
rr_head_q = w_head_q + self.r_r_bias
BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head
BD = torch.einsum("ibnd,jnd->ijbn", (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head
BD = self._rel_shift(BD)
# [qlen x klen x bsz x n_head]
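The BD term scores each query against relative-position embeddings, and _rel_shift realigns them so that column j of row i corresponds to the right relative distance. A 2-D sketch of the shift trick (the real tensors also carry bsz and n_head axes; values assumed):

import torch

qlen, klen = 3, 5
x = torch.arange(float(qlen * klen)).view(qlen, klen)
x_padded = torch.cat([torch.zeros(qlen, 1), x], dim=1)  # pad one zero column
shifted = x_padded.view(klen + 1, qlen)[1:].view(qlen, klen)
# row i is now shifted left by (qlen - 1 - i), aligning relative offsets per query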
......@@ -319,21 +329,19 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
#### compute attention probability
if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = (attn_mask == 1) # Switch to bool
attn_mask = attn_mask == 1 # Switch to bool
if attn_mask.dim() == 2:
if next(self.parameters()).dtype == torch.float16:
attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None], -65000).type_as(attn_score)
attn_score = (
attn_score.float().masked_fill(attn_mask[None, :, :, None], -65000).type_as(attn_score)
)
else:
attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None], -1e30).type_as(attn_score)
attn_score = attn_score.float().masked_fill(attn_mask[None, :, :, None], -1e30).type_as(attn_score)
elif attn_mask.dim() == 3:
if next(self.parameters()).dtype == torch.float16:
attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None], -65000).type_as(attn_score)
attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -65000).type_as(attn_score)
else:
attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None], -1e30).type_as(attn_score)
attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -1e30).type_as(attn_score)
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
......@@ -344,11 +352,10 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
attn_prob = attn_prob * head_mask
#### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
attn_out = self.o_net(attn_vec)
......@@ -368,21 +375,19 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5,
**kwargs):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5, **kwargs):
super(RelPartialLearnableDecoderLayer, self).__init__()
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, layer_norm_epsilon=layer_norm_epsilon, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'),
layer_norm_epsilon=layer_norm_epsilon)
self.dec_attn = RelPartialLearnableMultiHeadAttn(
n_head, d_model, d_head, dropout, layer_norm_epsilon=layer_norm_epsilon, **kwargs
)
self.pos_ff = PositionwiseFF(
d_model, d_inner, dropout, pre_lnorm=kwargs.get("pre_lnorm"), layer_norm_epsilon=layer_norm_epsilon
)
def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
attn_outputs = self.dec_attn(dec_inp, r,
attn_mask=dec_attn_mask,
mems=mems, head_mask=head_mask)
attn_outputs = self.dec_attn(dec_inp, r, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)
ff_output = self.pos_ff(attn_outputs[0])
outputs = [ff_output] + attn_outputs[1:]
......@@ -391,8 +396,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
class AdaptiveEmbedding(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
sample_softmax=False):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
super(AdaptiveEmbedding, self).__init__()
self.n_token = n_token
......@@ -409,28 +413,25 @@ class AdaptiveEmbedding(nn.Module):
self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ParameterList()
if div_val == 1:
self.emb_layers.append(
nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
)
self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
if d_proj != d_embed:
self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
d_emb_i = d_embed // (div_val ** i)
self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))
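With div_val > 1, each rarer vocabulary bucket gets a narrower embedding. A sketch assuming WT103-style values (cutoffs and sizes illustrative):

vocab_size, d_embed, div_val = 267735, 1024, 4
cutoffs = [20000, 40000, 200000] + [vocab_size]
cutoff_ends = [0] + cutoffs
for i in range(len(cutoffs)):
    l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
    print(r_idx - l_idx, "tokens at width", d_embed // (div_val ** i))
# 20000 tokens at width 1024, then 20000 at 256, 160000 at 64, 67735 at 16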
def forward(self, inp):
if self.div_val == 1:
embed = self.emb_layers[0](inp)
if self.d_proj != self.d_embed:
embed = F.linear(embed, self.emb_projs[0])
embed = F.linear(embed, self.emb_projs[0])
else:
param = next(self.parameters())
inp_flat = inp.view(-1)
emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
dtype=param.dtype, device=param.device)
emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
......@@ -458,15 +459,16 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = TransfoXLConfig
pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_transfo_xl
base_model_prefix = "transformer"
def _init_weight(self, weight):
if self.config.init == 'uniform':
if self.config.init == "uniform":
nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
elif self.config.init == 'normal':
elif self.config.init == "normal":
nn.init.normal_(weight, 0.0, self.config.init_std)
def _init_bias(self, bias):
......@@ -476,41 +478,41 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
""" Initialize the weights.
"""
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if hasattr(m, 'weight') and m.weight is not None:
if classname.find("Linear") != -1:
if hasattr(m, "weight") and m.weight is not None:
self._init_weight(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
if hasattr(m, "bias") and m.bias is not None:
self._init_bias(m.bias)
elif classname.find('AdaptiveEmbedding') != -1:
if hasattr(m, 'emb_projs'):
elif classname.find("AdaptiveEmbedding") != -1:
if hasattr(m, "emb_projs"):
for i in range(len(m.emb_projs)):
if m.emb_projs[i] is not None:
nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
elif classname.find('Embedding') != -1:
if hasattr(m, 'weight'):
elif classname.find("Embedding") != -1:
if hasattr(m, "weight"):
self._init_weight(m.weight)
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
elif classname.find("ProjectedAdaptiveLogSoftmax") != -1:
if hasattr(m, "cluster_weight") and m.cluster_weight is not None:
self._init_weight(m.cluster_weight)
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
if hasattr(m, "cluster_bias") and m.cluster_bias is not None:
self._init_bias(m.cluster_bias)
if hasattr(m, 'out_projs'):
if hasattr(m, "out_projs"):
for i in range(len(m.out_projs)):
if m.out_projs[i] is not None:
nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
elif classname.find('LayerNorm') != -1:
if hasattr(m, 'weight'):
elif classname.find("LayerNorm") != -1:
if hasattr(m, "weight"):
nn.init.normal_(m.weight, 1.0, self.config.init_std)
if hasattr(m, 'bias') and m.bias is not None:
if hasattr(m, "bias") and m.bias is not None:
self._init_bias(m.bias)
else:
if hasattr(m, 'r_emb'):
if hasattr(m, "r_emb"):
self._init_weight(m.r_emb)
if hasattr(m, 'r_w_bias'):
if hasattr(m, "r_w_bias"):
self._init_weight(m.r_w_bias)
if hasattr(m, 'r_r_bias'):
if hasattr(m, "r_r_bias"):
self._init_weight(m.r_r_bias)
if hasattr(m, 'r_bias'):
if hasattr(m, "r_bias"):
self._init_bias(m.r_bias)
......@@ -559,8 +561,12 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
TRANSFO_XL_START_DOCSTRING,
TRANSFO_XL_INPUTS_DOCSTRING,
)
class TransfoXLModel(TransfoXLPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
......@@ -587,6 +593,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
last_hidden_states, mems = outputs[:2]
"""
def __init__(self, config):
super(TransfoXLModel, self).__init__(config)
self.output_attentions = config.output_attentions
......@@ -599,8 +606,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self.n_head = config.n_head
self.d_head = config.d_head
self.word_emb = AdaptiveEmbedding(config.vocab_size, config.d_embed, config.d_model, config.cutoffs,
div_val=config.div_val)
self.word_emb = AdaptiveEmbedding(
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
)
self.drop = nn.Dropout(config.dropout)
......@@ -618,27 +626,35 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
self.layers = nn.ModuleList()
if config.attn_type == 0: # the default attention
if config.attn_type == 0: # the default attention
for i in range(config.n_layer):
self.layers.append(
RelPartialLearnableDecoderLayer(
config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
config.n_head,
config.d_model,
config.d_head,
config.d_inner,
config.dropout,
tgt_len=config.tgt_len,
ext_len=config.ext_len,
mem_len=config.mem_len,
dropatt=config.dropatt,
pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions,
layer_norm_epsilon=config.layer_norm_epsilon)
layer_norm_epsilon=config.layer_norm_epsilon,
)
)
else: # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
else: # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
raise NotImplementedError # Removed them to avoid maintaining dead code
self.same_length = config.same_length
self.clamp_len = config.clamp_len
if self.attn_type == 0: # default attention
if self.attn_type == 0: # default attention
self.pos_emb = PositionalEmbedding(self.d_model)
else: # learnable embeddings and absolute embeddings
else: # learnable embeddings and absolute embeddings
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
self.init_weights()
......@@ -666,8 +682,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
mems = []
param = next(self.parameters())
for i in range(self.n_layer):
empty = torch.zeros(self.mem_len, bsz, self.config.d_model,
dtype=param.dtype, device=param.device)
empty = torch.zeros(self.mem_len, bsz, self.config.d_model, dtype=param.dtype, device=param.device)
mems.append(empty)
return mems
......@@ -676,10 +691,11 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
def _update_mems(self, hids, mems, qlen, mlen):
# does not deal with None
if mems is None: return None
if mems is None:
return None
# mems is not None
assert len(hids) == len(mems), 'len(hids) != len(mems)'
assert len(hids) == len(mems), "len(hids) != len(mems)"
# There are `mlen + qlen` steps that can be cached into mems
# For the next step, the last `ext_len` of the `qlen` tokens
......@@ -725,7 +741,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
head_mask = head_mask.to(
dtype=next(self.parameters()).dtype
) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.n_layer
......@@ -743,17 +761,16 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
mask_shift_len = qlen - mask_len
else:
mask_shift_len = qlen
dec_attn_mask = (torch.triu(all_ones, 1+mlen)
+ torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
else:
dec_attn_mask = torch.triu(
word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[
:, :, None
]
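A toy view (values assumed) of the causal mask built above: every query may attend to all mlen memory positions and to keys up to and including itself:

import torch

qlen, mlen = 3, 2
mask = torch.triu(torch.ones((qlen, qlen + mlen), dtype=torch.uint8), diagonal=1 + mlen)
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.uint8) -- 1 marks positions that are masked out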
hids = []
attentions = []
if self.attn_type == 0: # default
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.attn_type == 0: # default
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
......@@ -764,12 +781,13 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
for i, layer in enumerate(self.layers):
hids.append(core_out)
mems_i = None if mems is None else mems[i]
layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
layer_outputs = layer(
core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i, head_mask=head_mask[i]
)
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
else: # learnable embeddings and absolute embeddings
else: # learnable embeddings and absolute embeddings
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
core_out = self.drop(core_out)
......@@ -791,9 +809,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return outputs # last hidden state, new_mems, (all hidden states), (all attentions)
@add_start_docstrings("""The Transformer-XL Model with a language modeling head on top
@add_start_docstrings(
"""The Transformer-XL Model with a language modeling head on top
(adaptive softmax with weights tied to the adaptive input embeddings)""",
TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
TRANSFO_XL_START_DOCSTRING,
TRANSFO_XL_INPUTS_DOCSTRING,
)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
......@@ -830,6 +851,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
prediction_scores, mems = outputs[:2]
"""
def __init__(self, config):
super(TransfoXLLMHeadModel, self).__init__(config)
self.transformer = TransfoXLModel(config)
......@@ -840,8 +862,9 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
self.sampler = LogUniformSampler(config.vocab_size, config.sample_softmax)
# use adaptive softmax (including standard softmax)
else:
self.crit = ProjectedAdaptiveLogSoftmax(config.vocab_size, config.d_embed, config.d_model,
config.cutoffs, div_val=config.div_val)
self.crit = ProjectedAdaptiveLogSoftmax(
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
)
self.init_weights()
def tie_weights(self):
......@@ -856,8 +879,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
else:
if self.config.tie_weight:
for i in range(len(self.crit.out_layers)):
self._tie_or_clone_weights(self.crit.out_layers[i],
self.transformer.word_emb.emb_layers[i])
self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
if self.config.tie_projs:
for i, tie_proj in enumerate(self.config.tie_projs):
if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
......
......@@ -28,9 +28,9 @@ import torch.nn.functional as F
# CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
# CUDA_MINOR = int(torch.version.cuda.split('.')[1])
class ProjectedAdaptiveLogSoftmax(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
keep_order=False):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, keep_order=False):
super(ProjectedAdaptiveLogSoftmax, self).__init__()
self.n_token = n_token
......@@ -55,23 +55,19 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
if div_val == 1:
for i in range(len(self.cutoffs)):
if d_proj != d_embed:
self.out_projs.append(
nn.Parameter(torch.FloatTensor(d_proj, d_embed))
)
self.out_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
else:
self.out_projs.append(None)
self.out_layers.append(nn.Linear(d_embed, n_token))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
d_emb_i = d_embed // (div_val ** i)
self.out_projs.append(
nn.Parameter(torch.FloatTensor(d_proj, d_emb_i))
)
self.out_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))
self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
self.out_layers.append(nn.Linear(d_emb_i, r_idx - l_idx))
self.keep_order = keep_order
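For orientation: the head softmax covers the first cutoff's frequent tokens plus one pseudo-token per tail cluster, so its size is cutoffs[0] + n_clusters (values assumed, WT103-style):

n_token, cutoffs = 267735, [20000, 40000, 200000]
full_cutoffs = cutoffs + [n_token]
n_clusters = len(full_cutoffs) - 1  # 3 tail clusters
head_size = full_cutoffs[0] + n_clusters  # a 20003-way head softmax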
......@@ -90,7 +86,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
return logit
def forward(self, hidden, labels=None, keep_order=False):
'''
"""
Params:
hidden :: [len*bsz x d_proj]
labels :: [len*bsz]
......@@ -102,20 +98,17 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
We could replace this implementation with the native PyTorch one
if theirs had an option to set bias on all clusters.
here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
'''
"""
if labels is not None:
labels = labels.view(-1)
if hidden.size(0) != labels.size(0):
raise RuntimeError('Input and labels should have the same size '
'in the batch dimension.')
raise RuntimeError("Input and labels should have the same size " "in the batch dimension.")
if self.n_clusters == 0:
logit = self._compute_logit(hidden, self.out_layers[0].weight,
self.out_layers[0].bias, self.out_projs[0])
logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
if labels is not None:
out = -F.log_softmax(logit, dim=-1) \
.gather(1, labels.unsqueeze(1)).squeeze(1)
out = -F.log_softmax(logit, dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)
else:
out = F.log_softmax(logit, dim=-1)
else:
......@@ -131,10 +124,8 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
bias_i = self.out_layers[i].bias
if i == 0:
weight_i = torch.cat(
[weight_i, self.cluster_weight], dim=0)
bias_i = torch.cat(
[bias_i, self.cluster_bias], dim=0)
weight_i = torch.cat([weight_i, self.cluster_weight], dim=0)
bias_i = torch.cat([bias_i, self.cluster_bias], dim=0)
weights.append(weight_i)
biases.append(bias_i)
......@@ -171,7 +162,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
if labels is not None:
logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
else:
out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
out[:, : self.cutoffs[0]] = head_logprob[:, : self.cutoffs[0]]
else:
weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
......@@ -179,22 +170,22 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster
if labels is not None:
logprob_i = head_logprob_i[:, cluster_prob_idx] \
+ tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
logprob_i = head_logprob_i[:, cluster_prob_idx] + tail_logprob_i.gather(
1, target_i[:, None]
).squeeze(1)
else:
logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
out[:, l_idx:r_idx] = logprob_i
if labels is not None:
if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
if (hasattr(self, "keep_order") and self.keep_order) or keep_order:
out.index_copy_(0, indices_i, -logprob_i)
else:
out[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
out[offset : offset + logprob_i.size(0)].copy_(-logprob_i)
offset += logprob_i.size(0)
return out
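In other words, a tail token's log-probability factorizes as log P(cluster | h) + log P(token | cluster, h); with toy numbers (values assumed):

import math

p_cluster = 0.10  # head-softmax mass on the token's tail cluster
p_within = 0.02  # tail-softmax mass on the token inside that cluster
log_p = math.log(p_cluster) + math.log(p_within)  # == log(0.002)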
def log_prob(self, hidden):
r""" Computes log probabilities for all :math:`n\_classes`
From: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py
......@@ -209,8 +200,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
- Output: :math:`(N, n\_classes)`
"""
if self.n_clusters == 0:
logit = self._compute_logit(hidden, self.out_layers[0].weight,
self.out_layers[0].bias, self.out_projs[0])
logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
return F.log_softmax(logit, dim=-1)
else:
# construct weights and biases
......@@ -225,10 +215,8 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
bias_i = self.out_layers[i].bias
if i == 0:
weight_i = torch.cat(
[weight_i, self.cluster_weight], dim=0)
bias_i = torch.cat(
[bias_i, self.cluster_bias], dim=0)
weight_i = torch.cat([weight_i, self.cluster_weight], dim=0)
bias_i = torch.cat([bias_i, self.cluster_bias], dim=0)
weights.append(weight_i)
biases.append(bias_i)
......@@ -244,7 +232,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1]
if i == 0:
out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
out[:, : self.cutoffs[0]] = head_logprob[:, : self.cutoffs[0]]
else:
weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
......@@ -270,10 +258,10 @@ class LogUniformSampler(object):
"""
with torch.no_grad():
self.range_max = range_max
log_indices = torch.arange(1., range_max+2., 1.).log_()
log_indices = torch.arange(1.0, range_max + 2.0, 1.0).log_()
self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
self.log_q = (-(-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
self.n_sample = n_sample
......@@ -298,6 +286,7 @@ class LogUniformSampler(object):
samp_log_probs = self.log_q[neg_samples].to(device)
return true_log_probs, samp_log_probs, neg_samples
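The sampling distribution built in __init__ is the log-uniform (Zipfian) one used by TF's log_uniform_candidate_sampler, P(k) = (log(k + 2) - log(k + 1)) / log(range_max + 1); a quick check that it normalizes:

import torch

range_max = 10
log_indices = torch.arange(1.0, range_max + 2.0).log_()
dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
print(dist.sum())  # tensor(1.0000): the telescoping sum collapses to 1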
def sample_logits(embedding, bias, labels, inputs, sampler):
"""
embedding: an nn.Embedding layer
......@@ -313,19 +302,17 @@ def sample_logits(embedding, bias, labels, inputs, sampler):
b1, b2 = labels.size(0), labels.size(1)
all_ids = torch.cat([labels.view(-1), neg_samples])
all_w = embedding(all_ids)
true_w = all_w[: -n_sample].view(b1, b2, -1)
sample_w = all_w[- n_sample:].view(n_sample, -1)
true_w = all_w[:-n_sample].view(b1, b2, -1)
sample_w = all_w[-n_sample:].view(n_sample, -1)
all_b = bias[all_ids]
true_b = all_b[: -n_sample].view(b1, b2)
sample_b = all_b[- n_sample:]
true_b = all_b[:-n_sample].view(b1, b2)
sample_b = all_b[-n_sample:]
hit = (labels[:, :, None] == neg_samples).detach()
true_logits = torch.einsum('ijk,ijk->ij',
[true_w, inputs]) + true_b - true_log_probs
sample_logits = torch.einsum('lk,ijk->ijl',
[sample_w, inputs]) + sample_b - samp_log_probs
true_logits = torch.einsum("ijk,ijk->ij", [true_w, inputs]) + true_b - true_log_probs
sample_logits = torch.einsum("lk,ijk->ijl", [sample_w, inputs]) + sample_b - samp_log_probs
sample_logits.masked_fill_(hit, -1e30)
logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
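Because the true logit is placed in column 0, the sampled-softmax training target is always class 0; a hedged sketch of how such logits could be consumed (loss wiring assumed, not taken from this diff):

import torch
import torch.nn.functional as F

flat_logits = logits.view(-1, logits.size(-1))  # [b1 * b2, 1 + n_sample]
target = flat_logits.new_zeros(flat_logits.size(0), dtype=torch.long)
loss = F.cross_entropy(flat_logits, target)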
......
......@@ -15,8 +15,7 @@
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import json
......@@ -31,8 +30,15 @@ from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from .configuration_utils import PretrainedConfig
from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, DUMMY_INPUTS,
cached_path, hf_bucket_url, is_remote_url)
from .file_utils import (
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_NAME,
DUMMY_INPUTS,
cached_path,
hf_bucket_url,
is_remote_url,
)
logger = logging.getLogger(__name__)
......@@ -43,12 +49,14 @@ except ImportError:
class Identity(nn.Module):
r"""A placeholder identity operator that is argument-insensitive.
"""
def __init__(self, *args, **kwargs):
super(Identity, self).__init__()
def forward(self, input):
return input
class PreTrainedModel(nn.Module):
r""" Base class for all models.
......@@ -78,7 +86,7 @@ class PreTrainedModel(nn.Module):
Returns:
torch.Tensor with dummy inputs
"""
return {'input_ids': torch.tensor(DUMMY_INPUTS)}
return {"input_ids": torch.tensor(DUMMY_INPUTS)}
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
......@@ -88,7 +96,8 @@ class PreTrainedModel(nn.Module):
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
)
)
# Save config in model
self.config = config
......@@ -136,14 +145,14 @@ class PreTrainedModel(nn.Module):
else:
output_embeddings.weight = input_embeddings.weight
if hasattr(output_embeddings, 'bias') and output_embeddings.bias is not None:
if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None:
output_embeddings.bias.data = torch.nn.functional.pad(
output_embeddings.bias.data,
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
'constant',
0
"constant",
0,
)
if hasattr(output_embeddings, 'out_features') and hasattr(input_embeddings, 'num_embeddings'):
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings
def resize_token_embeddings(self, new_num_tokens=None):
......@@ -244,10 +253,12 @@ class PreTrainedModel(nn.Module):
""" Save a model and its configuration file to a directory, so that it
can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
"""
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
assert os.path.isdir(
save_directory
), "Saving path should be a directory where the model and configuration can be saved"
# Only save the model itself if we are using distributed training
model_to_save = self.module if hasattr(self, 'module') else self
model_to_save = self.module if hasattr(self, "module") else self
# Save configuration file
model_to_save.config.save_pretrained(save_directory)
......@@ -329,21 +340,23 @@ class PreTrainedModel(nn.Module):
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
config = kwargs.pop('config', None)
state_dict = kwargs.pop('state_dict', None)
cache_dir = kwargs.pop('cache_dir', None)
from_tf = kwargs.pop('from_tf', False)
force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None)
output_loading_info = kwargs.pop('output_loading_info', False)
config = kwargs.pop("config", None)
state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
from_tf = kwargs.pop("from_tf", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
config_path, *model_args,
cache_dir=cache_dir, return_unused_kwargs=True,
config_path,
*model_args,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
......@@ -367,43 +380,56 @@ class PreTrainedModel(nn.Module):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
pretrained_model_name_or_path))
raise EnvironmentError(
"Error no file named {} found in directory {} or `from_tf` set to False".format(
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], pretrained_model_name_or_path
)
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
pretrained_model_name_or_path + ".index")
assert (
from_tf
), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
pretrained_model_name_or_path + ".index"
)
archive_file = pretrained_model_name_or_path + ".index"
else:
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME)
if from_tf:
raise EnvironmentError("Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name.")
raise EnvironmentError(
"Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name."
)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file)
msg = "Couldn't reach server at '{}' to download pretrained weights.".format(archive_file)
else:
msg = "Model name '{}' was not found in model name list ({}). " \
"We assumed '{}' was a path or url to model weight files named one of {} but " \
msg = (
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url to model weight files named one of {} but "
"couldn't find any such file at this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
", ".join(cls.pretrained_model_archive_map.keys()),
archive_file,
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME],
)
)
raise EnvironmentError(msg)
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
else:
resolved_archive_file = None
......@@ -412,27 +438,32 @@ class PreTrainedModel(nn.Module):
if state_dict is None and not from_tf:
try:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
state_dict = torch.load(resolved_archive_file, map_location="cpu")
except Exception:
raise OSError("Unable to load weights from PyTorch checkpoint file. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. ")
raise OSError(
"Unable to load weights from PyTorch checkpoint file. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
)
missing_keys = []
unexpected_keys = []
error_msgs = []
if from_tf:
if resolved_archive_file.endswith('.index'):
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
else:
# Load from our TensorFlow 2.0 checkpoints
try:
from transformers import load_tf2_checkpoint_in_pytorch_model
model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
except ImportError as e:
logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise e
else:
# Convert old format to new format if needed from a PyTorch state_dict
......@@ -440,10 +471,10 @@ class PreTrainedModel(nn.Module):
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if "gamma" in key:
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
......@@ -451,39 +482,53 @@ class PreTrainedModel(nn.Module):
state_dict[new_key] = state_dict.pop(old_key)
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module, prefix=''):
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(child, prefix + name + ".")
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ''
start_prefix = ""
model_to_load = model
if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
start_prefix = cls.base_model_prefix + '.'
if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
if not hasattr(model, cls.base_model_prefix) and any(
s.startswith(cls.base_model_prefix) for s in state_dict.keys()
):
start_prefix = cls.base_model_prefix + "."
if hasattr(model, cls.base_model_prefix) and not any(
s.startswith(cls.base_model_prefix) for s in state_dict.keys()
):
model_to_load = getattr(model, cls.base_model_prefix)
load(model_to_load, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
logger.info(
"Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys
)
)
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
logger.info(
"Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys
)
)
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
model.__class__.__name__, "\n\t".join(error_msgs)
)
)
model.tie_weights() # make sure word embedding weights are still tied if needed
......@@ -500,10 +545,22 @@ class PreTrainedModel(nn.Module):
return {"input_ids": input_ids}
@torch.no_grad()
def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
bos_token_id=None, pad_token_id=None, eos_token_ids=None,
length_penalty=None, num_return_sequences=None):
def generate(
self,
input_ids=None,
max_length=None,
do_sample=None,
num_beams=None,
temperature=None,
top_k=None,
top_p=None,
repetition_penalty=None,
bos_token_id=None,
pad_token_id=None,
eos_token_ids=None,
length_penalty=None,
num_return_sequences=None,
):
""" Sequence generator for models with a LM head.
The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
......@@ -543,8 +600,10 @@ class PreTrainedModel(nn.Module):
# We cannot generate if the model does not have an LM head
if self.get_output_embeddings() is None:
raise AttributeError("You tried to generate sequences with a model that does not have an LM head. "
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`)")
raise AttributeError(
"You tried to generate sequences with a model that does not have an LM head. "
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`)"
)
max_length = max_length if max_length is not None else self.config.max_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
......@@ -557,7 +616,9 @@ class PreTrainedModel(nn.Module):
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
num_return_sequences = num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
if input_ids is not None:
batch_size = input_ids.shape[0]  # overridden by the input batch_size
......@@ -575,13 +636,18 @@ class PreTrainedModel(nn.Module):
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a positive integer."
assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a positive integer."
assert isinstance(eos_token_ids, (list, tuple)) and all(e >= 0 for e in eos_token_ids), \
"`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert isinstance(eos_token_ids, (list, tuple)) and all(
e >= 0 for e in eos_token_ids
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert isinstance(num_return_sequences, int) and num_return_sequences > 0, "`num_return_sequences` should be a strictly positive integer."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a strictly positive integer."
if input_ids is None:
input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device)
input_ids = torch.full(
(batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device
)
else:
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
......@@ -592,28 +658,63 @@ class PreTrainedModel(nn.Module):
if num_return_sequences != 1:
# Expand input to num return sequences
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_return_sequences, cur_len) # (batch_size * num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(
batch_size * num_return_sequences, cur_len
) # (batch_size * num_return_sequences, cur_len)
effective_batch_size = batch_size * num_return_sequences
else:
effective_batch_size = batch_size
if num_beams > 1:
output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, effective_batch_size,
length_penalty, num_beams, vocab_size)
output = self._generate_beam_search(
input_ids,
cur_len,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
pad_token_id,
eos_token_ids,
effective_batch_size,
length_penalty,
num_beams,
vocab_size,
)
else:
output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, effective_batch_size)
output = self._generate_no_beam_search(
input_ids,
cur_len,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
pad_token_id,
eos_token_ids,
effective_batch_size,
)
if num_return_sequences != 1:
output = output.view(batch_size, num_return_sequences, -1)
return output
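
A minimal usage sketch of the dispatch above (the checkpoint name and tokenizer pairing are assumptions, not part of this diff, and the decoding defaults come from the model config):

# Minimal decoding sketch; checkpoint/tokenizer names are assumptions.
import torch
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")
input_ids = torch.tensor([tokenizer.encode("the quick brown fox")])  # (1, seq_len)

# num_beams == 1 takes the _generate_no_beam_search path; sampling with nucleus filtering.
sampled = model.generate(input_ids, max_length=20, do_sample=True, top_k=0, top_p=0.9, num_beams=1)

# num_beams > 1 takes the _generate_beam_search path; greedy beam search here.
beamed = model.generate(input_ids, max_length=20, do_sample=False, num_beams=3)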
def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size):
def _generate_no_beam_search(
self,
input_ids,
cur_len,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
pad_token_id,
eos_token_ids,
batch_size,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
"""
......@@ -663,23 +764,38 @@ class PreTrainedModel(nn.Module):
return input_ids
def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size,
length_penalty, num_beams, vocab_size):
def _generate_beam_search(
self,
input_ids,
cur_len,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
pad_token_id,
eos_token_ids,
batch_size,
length_penalty,
num_beams,
vocab_size,
):
""" Generate sequences for each example with beam search.
"""
# Expand input to num beams
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
# generated hypotheses
generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)]
generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
]
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states
pasts = None # self.prepare_pasts()
......@@ -689,8 +805,8 @@ class PreTrainedModel(nn.Module):
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
......@@ -703,25 +819,27 @@ class PreTrainedModel(nn.Module):
if temperature > 0 and temperature != 1.0:
scores = scores / temperature
# Top-p/top-k filtering
scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2) # (batch_size * num_beams, vocab_size)
scores = top_k_top_p_filtering(
scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
) # (batch_size * num_beams, vocab_size)
# Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2) # (batch_size * num_beams, 2)
next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2) # (batch_size * num_beams, 2)
# Compute next scores
_scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = torch.gather(_scores, -1, next_words) # (batch_size * num_beams, 2)
next_scores = _scores + beam_scores[:, None].expand_as(_scores) # (batch_size * num_beams, 2)
_scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = torch.gather(_scores, -1, next_words) # (batch_size * num_beams, 2)
next_scores = _scores + beam_scores[:, None].expand_as(_scores) # (batch_size * num_beams, 2)
# Match shape of greedy beam search
next_words = next_words.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams)
next_scores = next_scores.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams)
next_words = next_words.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams)
next_scores = next_scores.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams)
else:
# do greedy beam search
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
assert scores.size() == (batch_size * num_beams, vocab_size)
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# re-organize to group the beams together (we keep the top hypotheses across beams)
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
next_scores, next_words = torch.topk(_scores, 2*num_beams, dim=1, largest=True, sorted=True)
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
......@@ -750,7 +868,9 @@ class PreTrainedModel(nn.Module):
# end of sentence, or next word
if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
generated_hyps[batch_ex].add(input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item())
generated_hyps[batch_ex].add(
input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item()
)
else:
next_sent_beam.append((score, word_id, batch_ex * num_beams + beam_id))
......@@ -807,13 +927,13 @@ class PreTrainedModel(nn.Module):
# generate target batch
decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
for i, hypo in enumerate(best):
decoded[i, :tgt_len[i] - 1] = hypo
decoded[i, : tgt_len[i] - 1] = hypo
decoded[i, tgt_len[i] - 1] = eos_token_ids[0]
return decoded
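
The ranking that produced `best` above is a length-normalized log-probability; a standalone sketch of that scoring rule (assuming, as in `BeamHypotheses.add`, score = sum of token log-probs divided by length ** length_penalty):

def length_penalized_score(token_logprobs, length_penalty=1.0):
    # Sum of per-token log-probs, normalized by hypothesis length ** length_penalty.
    return sum(token_logprobs) / (len(token_logprobs) ** length_penalty)

short = [-0.1, -0.2]           # 2 tokens, total log-prob -0.3
long = [-0.1, -0.2, -0.25]     # 3 tokens, total log-prob -0.55
print(length_penalized_score(short), length_penalized_score(long))            # -0.150 vs -0.183: short wins
print(length_penalized_score(short, 2.0), length_penalized_score(long, 2.0))  # -0.075 vs -0.061: long wins
# A larger length_penalty therefore favors longer hypotheses.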
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf'), min_tokens_to_keep=1):
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
......@@ -849,7 +969,6 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')
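
A standalone sketch of using this filter to sample one token (toy logits; assumes the function is in scope):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])  # (batch=1, vocab=5)
filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=1.0)
# Only the two largest logits survive; the rest become filter_value (-inf),
# i.e. probability zero after the softmax.
probs = F.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)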
class BeamHypotheses(object):
def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
......@@ -915,6 +1034,7 @@ class Conv1D(nn.Module):
class PoolerStartLogits(nn.Module):
""" Compute SQuAD start_logits from sequence hidden states. """
def __init__(self, config):
super(PoolerStartLogits, self).__init__()
self.dense = nn.Linear(config.hidden_size, 1)
......@@ -939,6 +1059,7 @@ class PoolerStartLogits(nn.Module):
class PoolerEndLogits(nn.Module):
""" Compute SQuAD end_logits from sequence hidden states and start token hidden state.
"""
def __init__(self, config):
super(PoolerEndLogits, self).__init__()
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
......@@ -959,12 +1080,14 @@ class PoolerEndLogits(nn.Module):
Mask of invalid positions such as query and special symbols (PAD, SEP, CLS)
1.0 means token should be masked.
"""
assert start_states is not None or start_positions is not None, "One of start_states or start_positions should not be None"
assert (
start_states is not None or start_positions is not None
), "One of start_states or start_positions should not be None"
if start_positions is not None:
slen, hsz = hidden_states.shape[-2:]
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
x = self.activation(x)
......@@ -982,6 +1105,7 @@ class PoolerEndLogits(nn.Module):
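
The expand+gather idiom above selects one time step per example; a toy trace of just that indexing (standalone):

import torch

bsz, slen, hsz = 2, 4, 3
hidden_states = torch.arange(bsz * slen * hsz, dtype=torch.float).view(bsz, slen, hsz)
start_positions = torch.tensor([1, 3])  # one start index per example

idx = start_positions[:, None, None].expand(-1, -1, hsz)  # (bsz, 1, hsz)
start_states = hidden_states.gather(-2, idx)              # (bsz, 1, hsz)
assert torch.equal(start_states[0, 0], hidden_states[0, 1])
assert torch.equal(start_states[1, 0], hidden_states[1, 3])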
class PoolerAnswerClass(nn.Module):
""" Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
def __init__(self, config):
super(PoolerAnswerClass, self).__init__()
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
......@@ -1006,16 +1130,18 @@ class PoolerAnswerClass(nn.Module):
for each sample
"""
hsz = hidden_states.shape[-1]
assert start_states is not None or start_positions is not None, "One of start_states or start_positions should not be None"
assert (
start_states is not None or start_positions is not None
), "One of start_states or start_positions should not be None"
if start_positions is not None:
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
if cls_index is not None:
cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
else:
cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
x = self.activation(x)
......@@ -1064,6 +1190,7 @@ class SQuADHead(nn.Module):
``torch.FloatTensor`` of shape ``(batch_size,)``
Log probabilities for the ``is_impossible`` label of the answers.
"""
def __init__(self, config):
super(SQuADHead, self).__init__()
self.start_n_top = config.start_n_top
......@@ -1073,8 +1200,9 @@ class SQuADHead(nn.Module):
self.end_logits = PoolerEndLogits(config)
self.answer_class = PoolerAnswerClass(config)
def forward(self, hidden_states, start_positions=None, end_positions=None,
cls_index=None, is_impossible=None, p_mask=None):
def forward(
self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None
):
outputs = ()
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
......@@ -1107,19 +1235,25 @@ class SQuADHead(nn.Module):
else:
# during inference, compute the end logits based on beam search
bsz, slen, hsz = hidden_states.size()
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_top_log_probs, start_top_index = torch.topk(
start_log_probs, self.start_n_top, dim=-1
) # shape (bsz, start_n_top)
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
start_states
) # shape (bsz, slen, start_n_top, hsz)
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
end_top_log_probs, end_top_index = torch.topk(
end_log_probs, self.end_n_top, dim=1
) # shape (bsz, end_n_top, start_n_top)
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
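
A toy shape walk-through of the inference-time start/end beam above (standalone, arbitrary sizes; note the `*_log_probs` names above actually hold softmax probabilities, not logs):

import torch
import torch.nn.functional as F

bsz, slen, hsz, start_n_top = 2, 5, 4, 3
start_logits = torch.randn(bsz, slen)
hidden_states = torch.randn(bsz, slen, hsz)

start_probs = F.softmax(start_logits, dim=-1)                      # (bsz, slen)
top_p, top_i = torch.topk(start_probs, start_n_top, dim=-1)        # (bsz, start_n_top)
idx = top_i.unsqueeze(-1).expand(-1, -1, hsz)                      # (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, idx)                # (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # (bsz, slen, start_n_top, hsz)
# Every candidate start then gets its own end-logits pass over the full sequence.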
......@@ -1148,34 +1282,35 @@ class SequenceSummary(nn.Module):
summary_first_dropout: Add a dropout before the projection and activation
summary_last_dropout: Add a dropout after the projection and activation
"""
def __init__(self, config):
super(SequenceSummary, self).__init__()
self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last'
if self.summary_type == 'attn':
self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
if self.summary_type == "attn":
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
raise NotImplementedError
self.summary = Identity()
if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
if hasattr(config, "summary_use_proj") and config.summary_use_proj:
if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
num_classes = config.num_labels
else:
num_classes = config.hidden_size
self.summary = nn.Linear(config.hidden_size, num_classes)
self.activation = Identity()
if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
if hasattr(config, "summary_activation") and config.summary_activation == "tanh":
self.activation = nn.Tanh()
self.first_dropout = Identity()
if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
self.first_dropout = nn.Dropout(config.summary_first_dropout)
self.last_dropout = Identity()
if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
self.last_dropout = nn.Dropout(config.summary_last_dropout)
def forward(self, hidden_states, cls_index=None):
......@@ -1185,21 +1320,21 @@ class SequenceSummary(nn.Module):
if summary_type == 'cls_index' and cls_index is None:
we take the last token of the sequence as the classification token
"""
if self.summary_type == 'last':
if self.summary_type == "last":
output = hidden_states[:, -1]
elif self.summary_type == 'first':
elif self.summary_type == "first":
output = hidden_states[:, 0]
elif self.summary_type == 'mean':
elif self.summary_type == "mean":
output = hidden_states.mean(dim=1)
elif self.summary_type == 'cls_index':
elif self.summary_type == "cls_index":
if cls_index is None:
cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long)
else:
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
elif self.summary_type == 'attn':
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
elif self.summary_type == "attn":
raise NotImplementedError
output = self.first_dropout(output)
......
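
The four implemented summary modes above reduce to simple tensor ops; a standalone toy sketch:

import torch

hidden_states = torch.randn(2, 7, 16)  # (bsz, seq_len, hidden_size)

last = hidden_states[:, -1]         # summary_type == "last"
first = hidden_states[:, 0]         # summary_type == "first"
mean = hidden_states.mean(dim=1)    # summary_type == "mean"

# summary_type == "cls_index": gather one position per example.
cls_index = torch.tensor([3, 5])
idx = cls_index[:, None, None].expand(-1, 1, hidden_states.size(-1))  # (bsz, 1, hidden_size)
cls_tok = hidden_states.gather(-2, idx).squeeze(-2)                   # (bsz, hidden_size)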
......@@ -34,24 +34,21 @@ from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-pytorch_model.bin",
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-pytorch_model.bin",
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-pytorch_model.bin",
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-pytorch_model.bin",
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-pytorch_model.bin",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-pytorch_model.bin",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.bin",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-pytorch_model.bin",
"xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
"xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-pytorch_model.bin",
"xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-pytorch_model.bin",
"xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-pytorch_model.bin",
"xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-pytorch_model.bin",
"xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin",
"xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-pytorch_model.bin",
"xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-pytorch_model.bin",
"xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.bin",
"xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-pytorch_model.bin",
}
def create_sinusoidal_embeddings(n_pos, dim, out):
position_enc = np.array([
[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
for pos in range(n_pos)
])
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
out.detach_()
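
For intuition, the table above satisfies PE[pos, 2i] = sin(pos / 10000^(2i/dim)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/dim)); a quick standalone check (assumes the function is in scope):

import numpy as np
import torch

n_pos, dim = 10, 8
out = torch.zeros(n_pos, dim)
create_sinusoidal_embeddings(n_pos, dim, out)
pos, j = 3, 2  # an even column -> sine
expected = np.sin(pos / np.power(10000, 2 * (j // 2) / dim))
assert abs(out[pos, j].item() - expected) < 1e-6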
......@@ -142,7 +139,7 @@ class MultiHeadAttention(nn.Module):
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
bs, qlen, dim = input.size()
if kv is None:
klen = qlen if cache is None else cache['slen'] + qlen
klen = qlen if cache is None else cache["slen"] + qlen
else:
klen = kv.size(1)
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
......@@ -158,39 +155,39 @@ class MultiHeadAttention(nn.Module):
""" compute context """
return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
if kv is None:
k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
elif cache is None or self.layer_id not in cache:
k = v = kv
k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head)
k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head)
if cache is not None:
if self.layer_id in cache:
if kv is None:
k_, v_ = cache[self.layer_id]
k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head)
v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head)
k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head)
v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = cache[self.layer_id]
cache[self.layer_id] = (k, v)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, qlen, klen)
weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
weights = F.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
# Mask heads if we want to
if head_mask is not None:
weights = weights * head_mask
context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim)
context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim)
outputs = (self.out_lin(context),)
if self.output_attentions:
......@@ -199,7 +196,6 @@ class MultiHeadAttention(nn.Module):
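
A standalone toy of the incremental-decoding cache pattern used above (each step appends one position along the key-length axis, dim 2):

import torch

cache = {"slen": 0}
layer_id = 0
for step in range(3):
    k_new = torch.randn(1, 8, 1, 64)  # (bs, n_heads, 1 new position, dim_per_head)
    v_new = torch.randn(1, 8, 1, 64)
    if layer_id in cache:
        k_, v_ = cache[layer_id]
        k_new = torch.cat([k_, k_new], dim=2)
        v_new = torch.cat([v_, v_new], dim=2)
    cache[layer_id] = (k_new, v_new)
    cache["slen"] += 1  # in the model, this is bumped once per forward pass

assert cache[layer_id][0].shape[2] == cache["slen"] == 3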
class TransformerFFN(nn.Module):
def __init__(self, in_dim, dim_hidden, out_dim, config):
super(TransformerFFN, self).__init__()
self.dropout = config.dropout
......@@ -219,6 +215,7 @@ class XLMPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = XLMConfig
pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = None
......@@ -235,7 +232,7 @@ class XLMPreTrainedModel(PreTrainedModel):
langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
else:
langs_list = None
return {'input_ids': inputs_list, 'attention_mask': attns_list, 'langs': langs_list}
return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
def _init_weights(self, module):
""" Initialize the weights. """
......@@ -245,8 +242,8 @@ class XLMPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear):
if self.config is not None and self.config.init_std is not None:
nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, 0.)
if hasattr(module, "bias") and module.bias is not None:
nn.init.constant_(module.bias, 0.0)
if isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
......@@ -327,8 +324,12 @@ XLM_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
@add_start_docstrings(
"The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
XLM_START_DOCSTRING,
XLM_INPUTS_DOCSTRING,
)
class XLMModel(XLMPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
......@@ -351,7 +352,8 @@ class XLMModel(XLMPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config): #, dico, is_encoder, with_output):
def __init__(self, config): # , dico, is_encoder, with_output):
super(XLMModel, self).__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
......@@ -377,13 +379,13 @@ class XLMModel(XLMPreTrainedModel):
# assert len(self.id2lang) == len(self.lang2id) == self.n_langs
# model parameters
self.dim = config.emb_dim # 512 by default
self.dim = config.emb_dim # 512 by default
self.hidden_dim = self.dim * 4 # 2048 by default
self.n_heads = config.n_heads # 8 by default
self.n_heads = config.n_heads # 8 by default
self.n_layers = config.n_layers
self.dropout = config.dropout
self.attention_dropout = config.attention_dropout
assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'
assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"
# embeddings
self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
......@@ -435,8 +437,18 @@ class XLMModel(XLMPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.attentions[layer].prune_heads(heads)
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None): # removed: src_enc=None, src_len=None
def forward(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
): # removed: src_enc=None, src_len=None
if input_ids is not None:
bs, slen = input_ids.size()
else:
......@@ -446,7 +458,7 @@ class XLMModel(XLMPreTrainedModel):
if input_ids is not None:
lengths = (input_ids != self.pad_index).sum(dim=1).long()
else:
lengths = torch.LongTensor([slen]*bs)
lengths = torch.LongTensor([slen] * bs)
# mask = input_ids != self.pad_index
# check inputs
......@@ -488,14 +500,18 @@ class XLMModel(XLMPreTrainedModel):
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
head_mask = (
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
) # We can specify head_mask for each layer
head_mask = head_mask.to(
dtype=next(self.parameters()).dtype
) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.n_layers
# do not recompute cached elements
if cache is not None and input_ids is not None:
_slen = slen - cache['slen']
_slen = slen - cache["slen"]
input_ids = input_ids[:, -_slen:]
position_ids = position_ids[:, -_slen:]
if langs is not None:
......@@ -550,7 +566,7 @@ class XLMModel(XLMPreTrainedModel):
# update cache length
if cache is not None:
cache['slen'] += tensor.size(1)
cache["slen"] += tensor.size(1)
# move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1)
......@@ -567,6 +583,7 @@ class XLMPredLayer(nn.Module):
"""
Prediction layer (cross_entropy or adaptive_softmax).
"""
def __init__(self, config):
super(XLMPredLayer, self).__init__()
self.asm = config.asm
......@@ -593,7 +610,7 @@ class XLMPredLayer(nn.Module):
scores = self.proj(x)
outputs = (scores,) + outputs
if y is not None:
loss = F.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction='elementwise_mean')
loss = F.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction="mean")
outputs = (loss,) + outputs
else:
scores = self.proj.log_prob(x)
......@@ -605,9 +622,12 @@ class XLMPredLayer(nn.Module):
return outputs
@add_start_docstrings("""The XLM Model transformer with a language modeling head on top
@add_start_docstrings(
"""The XLM Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """,
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
XLM_START_DOCSTRING,
XLM_INPUTS_DOCSTRING,
)
class XLMWithLMHeadModel(XLMPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
......@@ -639,6 +659,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(XLMWithLMHeadModel, self).__init__(config)
self.transformer = XLMModel(config)
......@@ -661,17 +682,30 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
langs = None
return {"input_ids": input_ids, "langs": langs}
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
output = transformer_outputs[0]
outputs = self.pred_layer(output, labels)
......@@ -680,9 +714,12 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
return outputs
@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
@add_start_docstrings(
"""XLM Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
XLM_START_DOCSTRING,
XLM_INPUTS_DOCSTRING,
)
class XLMForSequenceClassification(XLMPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -714,6 +751,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
loss, logits = outputs[:2]
"""
def __init__(self, config):
super(XLMForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
......@@ -723,17 +761,30 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
......@@ -753,9 +804,12 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
return outputs
@add_start_docstrings("""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
@add_start_docstrings(
"""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
XLM_START_DOCSTRING,
XLM_INPUTS_DOCSTRING,
)
class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -799,6 +853,7 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
loss, start_scores, end_scores = outputs[:3]
"""
def __init__(self, config):
super(XLMForQuestionAnsweringSimple, self).__init__(config)
......@@ -807,17 +862,31 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
):
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
sequence_output = transformer_outputs[0]
......@@ -826,7 +895,10 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,)
outputs = (
start_logits,
end_logits,
)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
......@@ -849,9 +921,12 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
return outputs
@add_start_docstrings("""XLM Model with a beam-search span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
@add_start_docstrings(
"""XLM Model with a beam-search span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
XLM_START_DOCSTRING,
XLM_INPUTS_DOCSTRING,
)
class XLMForQuestionAnswering(XLMPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -895,6 +970,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
loss, start_scores, end_scores = outputs[:3]
"""
def __init__(self, config):
super(XLMForQuestionAnswering, self).__init__(config)
......@@ -903,23 +979,45 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None,
is_impossible=None, cls_index=None, p_mask=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
is_impossible=None,
cls_index=None,
p_mask=None,
):
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
output = transformer_outputs[0]
outputs = self.qa_outputs(output, start_positions=start_positions, end_positions=end_positions,
cls_index=cls_index, is_impossible=is_impossible, p_mask=p_mask)
outputs = self.qa_outputs(
output,
start_positions=start_positions,
end_positions=end_positions,
cls_index=cls_index,
is_impossible=is_impossible,
p_mask=p_mask,
)
outputs = outputs + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
......
......@@ -15,24 +15,29 @@
# limitations under the License.
"""PyTorch XLM-RoBERTa model. """
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification
from .modeling_roberta import (
RobertaModel,
RobertaForMaskedLM,
RobertaForSequenceClassification,
RobertaForMultipleChoice,
RobertaForTokenClassification,
)
from .configuration_xlm_roberta import XLMRobertaConfig
from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xlm-roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-pytorch_model.bin",
'xlm-roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-pytorch_model.bin",
'xlm-roberta-large-finetuned-conll02-dutch': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-pytorch_model.bin",
'xlm-roberta-large-finetuned-conll02-spanish': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-pytorch_model.bin",
'xlm-roberta-large-finetuned-conll03-english': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-pytorch_model.bin",
'xlm-roberta-large-finetuned-conll03-german': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-pytorch_model.bin",
"xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-pytorch_model.bin",
"xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-pytorch_model.bin",
"xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-pytorch_model.bin",
"xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-pytorch_model.bin",
"xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-pytorch_model.bin",
"xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-pytorch_model.bin",
}
......@@ -105,8 +110,12 @@ XLM_ROBERTA_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
XLM_ROBERTA_START_DOCSTRING, XLM_ROBERTA_INPUTS_DOCSTRING)
@add_start_docstrings(
"The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
XLM_ROBERTA_START_DOCSTRING,
XLM_ROBERTA_INPUTS_DOCSTRING,
)
class XLMRobertaModel(RobertaModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
......@@ -154,8 +163,11 @@ class XLMRobertaModel(RobertaModel):
pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings("""XLM-RoBERTa Model with a `language modeling` head on top. """,
XLM_ROBERTA_START_DOCSTRING, XLM_ROBERTA_INPUTS_DOCSTRING)
@add_start_docstrings(
"""XLM-RoBERTa Model with a `language modeling` head on top. """,
XLM_ROBERTA_START_DOCSTRING,
XLM_ROBERTA_INPUTS_DOCSTRING,
)
class XLMRobertaForMaskedLM(RobertaForMaskedLM):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
......@@ -190,9 +202,12 @@ class XLMRobertaForMaskedLM(RobertaForMaskedLM):
pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings("""XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer
@add_start_docstrings(
"""XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer
on top of the pooled output) e.g. for GLUE tasks. """,
XLM_ROBERTA_START_DOCSTRING, XLM_ROBERTA_INPUTS_DOCSTRING)
XLM_ROBERTA_START_DOCSTRING,
XLM_ROBERTA_INPUTS_DOCSTRING,
)
class XLMRobertaForSequenceClassification(RobertaForSequenceClassification):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -228,9 +243,12 @@ class XLMRobertaForSequenceClassification(RobertaForSequenceClassification):
pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings("""XLM-RoBERTa Model with a multiple choice classification head on top (a linear layer on top of
@add_start_docstrings(
"""XLM-RoBERTa Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
XLM_ROBERTA_START_DOCSTRING, XLM_ROBERTA_INPUTS_DOCSTRING)
XLM_ROBERTA_START_DOCSTRING,
XLM_ROBERTA_INPUTS_DOCSTRING,
)
class XLMRobertaForMultipleChoice(RobertaForMultipleChoice):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
......@@ -262,9 +280,12 @@ class XLMRobertaForMultipleChoice(RobertaForMultipleChoice):
pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings("""XLM-RoBERTa Model with a token classification head on top (a linear layer on top of
@add_start_docstrings(
"""XLM-RoBERTa Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
XLM_ROBERTA_START_DOCSTRING, XLM_ROBERTA_INPUTS_DOCSTRING)
XLM_ROBERTA_START_DOCSTRING,
XLM_ROBERTA_INPUTS_DOCSTRING,
)
class XLMRobertaForTokenClassification(RobertaForTokenClassification):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
......
......@@ -29,7 +29,14 @@ from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .modeling_utils import PreTrainedModel, prune_linear_layer, SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits
from .modeling_utils import (
PreTrainedModel,
prune_linear_layer,
SequenceSummary,
PoolerAnswerClass,
PoolerEndLogits,
PoolerStartLogits,
)
from .configuration_xlnet import XLNetConfig
from .file_utils import add_start_docstrings
......@@ -37,8 +44,8 @@ from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-pytorch_model.bin",
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
"xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-pytorch_model.bin",
"xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
}
......@@ -50,44 +57,53 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
tf_to_pt_map = {}
if hasattr(model, 'transformer'):
if hasattr(model, 'lm_loss'):
if hasattr(model, "transformer"):
if hasattr(model, "lm_loss"):
# We will also load the output bias
tf_to_pt_map['model/lm_loss/bias'] = model.lm_loss.bias
if hasattr(model, 'sequence_summary') and 'model/sequnece_summary/summary/kernel' in tf_weights:
tf_to_pt_map["model/lm_loss/bias"] = model.lm_loss.bias
if hasattr(model, "sequence_summary") and "model/sequnece_summary/summary/kernel" in tf_weights:
# We will also load the sequence summary
tf_to_pt_map['model/sequnece_summary/summary/kernel'] = model.sequence_summary.summary.weight
tf_to_pt_map['model/sequnece_summary/summary/bias'] = model.sequence_summary.summary.bias
if hasattr(model, 'logits_proj') and config.finetuning_task is not None \
and 'model/regression_{}/logit/kernel'.format(config.finetuning_task) in tf_weights:
tf_to_pt_map['model/regression_{}/logit/kernel'.format(config.finetuning_task)] = model.logits_proj.weight
tf_to_pt_map['model/regression_{}/logit/bias'.format(config.finetuning_task)] = model.logits_proj.bias
tf_to_pt_map["model/sequnece_summary/summary/kernel"] = model.sequence_summary.summary.weight
tf_to_pt_map["model/sequnece_summary/summary/bias"] = model.sequence_summary.summary.bias
if (
hasattr(model, "logits_proj")
and config.finetuning_task is not None
and "model/regression_{}/logit/kernel".format(config.finetuning_task) in tf_weights
):
tf_to_pt_map["model/regression_{}/logit/kernel".format(config.finetuning_task)] = model.logits_proj.weight
tf_to_pt_map["model/regression_{}/logit/bias".format(config.finetuning_task)] = model.logits_proj.bias
# Now load the rest of the transformer
model = model.transformer
# Embeddings and output
tf_to_pt_map.update({'model/transformer/word_embedding/lookup_table': model.word_embedding.weight,
'model/transformer/mask_emb/mask_emb': model.mask_emb})
tf_to_pt_map.update(
{
"model/transformer/word_embedding/lookup_table": model.word_embedding.weight,
"model/transformer/mask_emb/mask_emb": model.mask_emb,
}
)
# Transformer blocks
for i, b in enumerate(model.layer):
layer_str = "model/transformer/layer_%d/" % i
tf_to_pt_map.update({
layer_str + "rel_attn/LayerNorm/gamma": b.rel_attn.layer_norm.weight,
layer_str + "rel_attn/LayerNorm/beta": b.rel_attn.layer_norm.bias,
layer_str + "rel_attn/o/kernel": b.rel_attn.o,
layer_str + "rel_attn/q/kernel": b.rel_attn.q,
layer_str + "rel_attn/k/kernel": b.rel_attn.k,
layer_str + "rel_attn/r/kernel": b.rel_attn.r,
layer_str + "rel_attn/v/kernel": b.rel_attn.v,
layer_str + "ff/LayerNorm/gamma": b.ff.layer_norm.weight,
layer_str + "ff/LayerNorm/beta": b.ff.layer_norm.bias,
layer_str + "ff/layer_1/kernel": b.ff.layer_1.weight,
layer_str + "ff/layer_1/bias": b.ff.layer_1.bias,
layer_str + "ff/layer_2/kernel": b.ff.layer_2.weight,
layer_str + "ff/layer_2/bias": b.ff.layer_2.bias,
})
tf_to_pt_map.update(
{
layer_str + "rel_attn/LayerNorm/gamma": b.rel_attn.layer_norm.weight,
layer_str + "rel_attn/LayerNorm/beta": b.rel_attn.layer_norm.bias,
layer_str + "rel_attn/o/kernel": b.rel_attn.o,
layer_str + "rel_attn/q/kernel": b.rel_attn.q,
layer_str + "rel_attn/k/kernel": b.rel_attn.k,
layer_str + "rel_attn/r/kernel": b.rel_attn.r,
layer_str + "rel_attn/v/kernel": b.rel_attn.v,
layer_str + "ff/LayerNorm/gamma": b.ff.layer_norm.weight,
layer_str + "ff/LayerNorm/beta": b.ff.layer_norm.bias,
layer_str + "ff/layer_1/kernel": b.ff.layer_1.weight,
layer_str + "ff/layer_1/bias": b.ff.layer_1.bias,
layer_str + "ff/layer_2/kernel": b.ff.layer_2.weight,
layer_str + "ff/layer_2/bias": b.ff.layer_2.bias,
}
)
# Relative positioning biases
if config.untie_r:
......@@ -105,13 +121,17 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
r_w_list = [model.r_w_bias]
r_s_list = [model.r_s_bias]
seg_embed_list = [model.seg_embed]
tf_to_pt_map.update({
'model/transformer/r_r_bias': r_r_list,
'model/transformer/r_w_bias': r_w_list,
'model/transformer/r_s_bias': r_s_list,
'model/transformer/seg_embed': seg_embed_list})
tf_to_pt_map.update(
{
"model/transformer/r_r_bias": r_r_list,
"model/transformer/r_w_bias": r_w_list,
"model/transformer/r_s_bias": r_s_list,
"model/transformer/seg_embed": seg_embed_list,
}
)
return tf_to_pt_map
def load_tf_weights_in_xlnet(model, config, tf_path):
""" Load tf checkpoints in a pytorch model
"""
......@@ -119,8 +139,10 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
import numpy as np
import tensorflow as tf
except ImportError:
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
logger.error(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
......@@ -141,7 +163,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
array = tf_weights[name]
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
# which are not required for using the pretrained model
if 'kernel' in name and ('ff' in name or 'summary' in name or 'logit' in name):
if "kernel" in name and ("ff" in name or "summary" in name or "logit" in name):
logger.info("Transposing")
array = np.transpose(array)
if isinstance(pointer, list):
......@@ -165,10 +187,10 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
tf_weights.pop(name, None)
tf_weights.pop(name + '/Adam', None)
tf_weights.pop(name + '/Adam_1', None)
tf_weights.pop(name + "/Adam", None)
tf_weights.pop(name + "/Adam_1", None)
logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys())))
return model
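
The transposition above reflects a layout difference: TF dense kernels are stored as (in_features, out_features) while torch.nn.Linear keeps weight as (out_features, in_features); a toy sketch:

import numpy as np
import torch

tf_kernel = np.random.randn(1024, 4096).astype(np.float32)  # TF layout: (in, out)
linear = torch.nn.Linear(1024, 4096)
linear.weight.data = torch.from_numpy(np.transpose(tf_kernel))  # torch layout: (out, in)
assert linear.weight.shape == (4096, 1024)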
......@@ -199,7 +221,8 @@ class XLNetRelativeAttention(nn.Module):
if config.d_model % config.n_head != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.d_model, config.n_head))
"heads (%d)" % (config.d_model, config.n_head)
)
self.n_head = config.n_head
self.d_head = config.d_head
......@@ -242,7 +265,7 @@ class XLNetRelativeAttention(nn.Module):
x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2])
x = x[:, :, 1:, :]
x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3]-1)
x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1)
# Note: the tensor-slice form was faster in my testing than torch.index_select
# However, tracing doesn't like the nature of the slice, and if klen changes
# during the run then it'll fail, whereas index_select will be fine.
......@@ -255,27 +278,27 @@ class XLNetRelativeAttention(nn.Module):
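
The reshape-slice-reshape in rel_shift_bnij above is the standard relative-to-absolute realignment trick; a toy numeric trace (standalone):

import torch

b, n, qlen, klen = 1, 1, 3, 4
x = torch.arange(qlen * klen).view(b, n, qlen, klen)
# x[0, 0] = [[0, 1,  2,  3],
#            [4, 5,  6,  7],
#            [8, 9, 10, 11]]
y = x.reshape(b, n, klen, qlen)[:, :, 1:, :].reshape(b, n, qlen, klen - 1)
# y[0, 0] = [[3,  4,  5],
#            [6,  7,  8],
#            [9, 10, 11]]
# Reading the flat buffer with one row dropped makes each successive query row
# start one step further left within its original row, which maps relative
# offsets onto absolute key positions.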
"""Core relative positional attention operations."""
# content based attention score
ac = torch.einsum('ibnd,jbnd->bnij', q_head + self.r_w_bias, k_head_h)
ac = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_w_bias, k_head_h)
# position based attention score
bd = torch.einsum('ibnd,jbnd->bnij', q_head + self.r_r_bias, k_head_r)
bd = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_r_bias, k_head_r)
bd = self.rel_shift_bnij(bd, klen=ac.shape[3])
# segment based attention score
if seg_mat is None:
ef = 0
else:
ef = torch.einsum('ibnd,snd->ibns', q_head + self.r_s_bias, self.seg_embed)
ef = torch.einsum('ijbs,ibns->bnij', seg_mat, ef)
ef = torch.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
ef = torch.einsum("ijbs,ibns->bnij", seg_mat, ef)
# merge attention scores and perform masking
attn_score = (ac + bd + ef) * self.scale
if attn_mask is not None:
# attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
if attn_mask.dtype == torch.float16:
attn_score = attn_score - 65500 * torch.einsum('ijbn->bnij', attn_mask)
attn_score = attn_score - 65500 * torch.einsum("ijbn->bnij", attn_mask)
else:
attn_score = attn_score - 1e30 * torch.einsum('ijbn->bnij', attn_mask)
attn_score = attn_score - 1e30 * torch.einsum("ijbn->bnij", attn_mask)
# attention probability
attn_prob = F.softmax(attn_score, dim=3)
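# A minimal sketch of why the masking constant depends on dtype: 1e30 overflows
# float16 (whose maximum is about 65504), so a large finite value is used instead.
import torch

assert torch.tensor(1e30, dtype=torch.float16).isinf()
assert not torch.tensor(65500.0, dtype=torch.float16).isinf()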
......@@ -283,20 +306,20 @@ class XLNetRelativeAttention(nn.Module):
# Mask heads if we want to
if head_mask is not None:
attn_prob = attn_prob * torch.einsum('ijbn->bnij', head_mask)
attn_prob = attn_prob * torch.einsum("ijbn->bnij", head_mask)
# attention output
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h)
attn_vec = torch.einsum("bnij,jbnd->ibnd", attn_prob, v_head_h)
if self.output_attentions:
return attn_vec, torch.einsum('bnij->ijbn', attn_prob)
return attn_vec, torch.einsum("bnij->ijbn", attn_prob)
return attn_vec
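# A minimal sketch of the core einsum above: "ibnd,jbnd->bnij" contracts the
# head dimension d between query positions i and key positions j, i.e. a
# batched QK^T per batch b and head n. Shapes are toy values, not the model's.
import torch

i, j, b, n, d = 3, 5, 2, 4, 8
q = torch.randn(i, b, n, d)
k = torch.randn(j, b, n, d)
scores = torch.einsum("ibnd,jbnd->bnij", q, k)
ref = torch.matmul(q.permute(1, 2, 0, 3), k.permute(1, 2, 3, 0))
assert scores.shape == (b, n, i, j)
assert torch.allclose(scores, ref, atol=1e-5)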
def post_attention(self, h, attn_vec, residual=True):
"""Post-attention processing."""
# post-attention projection (back to `d_model`)
attn_out = torch.einsum('ibnd,hnd->ibh', attn_vec, self.o)
attn_out = torch.einsum("ibnd,hnd->ibh", attn_vec, self.o)
attn_out = self.dropout(attn_out)
if residual:
......@@ -305,10 +328,7 @@ class XLNetRelativeAttention(nn.Module):
return output
def forward(self, h, g,
attn_mask_h, attn_mask_g,
r, seg_mat,
mems=None, target_mapping=None, head_mask=None):
def forward(self, h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None):
if g is not None:
###### Two-stream attention with relative positional encoding.
# content based attention score
......@@ -318,21 +338,22 @@ class XLNetRelativeAttention(nn.Module):
cat = h
# content-based key head
k_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.k)
k_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.k)
# content-based value head
v_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.v)
v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)
# position-based key head
k_head_r = torch.einsum('ibh,hnd->ibnd', r, self.r)
k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)
##### h-stream
# content-stream query head
q_head_h = torch.einsum('ibh,hnd->ibnd', h, self.q)
q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
# core attention ops
attn_vec_h = self.rel_attn_core(
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask)
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask
)
if self.output_attentions:
attn_vec_h, attn_prob_h = attn_vec_h
......@@ -342,21 +363,23 @@ class XLNetRelativeAttention(nn.Module):
##### g-stream
# query-stream query head
q_head_g = torch.einsum('ibh,hnd->ibnd', g, self.q)
q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)
# core attention ops
if target_mapping is not None:
q_head_g = torch.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
q_head_g = torch.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
attn_vec_g = self.rel_attn_core(
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask)
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask
)
if self.output_attentions:
attn_vec_g, attn_prob_g = attn_vec_g
attn_vec_g = torch.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
attn_vec_g = torch.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
else:
attn_vec_g = self.rel_attn_core(
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask)
q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask
)
if self.output_attentions:
attn_vec_g, attn_prob_g = attn_vec_g
......@@ -375,16 +398,17 @@ class XLNetRelativeAttention(nn.Module):
cat = h
# content heads
q_head_h = torch.einsum('ibh,hnd->ibnd', h, self.q)
k_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.k)
v_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.v)
q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
k_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.k)
v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)
# positional heads
k_head_r = torch.einsum('ibh,hnd->ibnd', r, self.r)
k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)
# core attention ops
attn_vec = self.rel_attn_core(
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask)
q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask
)
if self.output_attentions:
attn_vec, attn_prob = attn_vec
......@@ -398,6 +422,7 @@ class XLNetRelativeAttention(nn.Module):
outputs = outputs + (attn_prob,)
return outputs
class XLNetFeedForward(nn.Module):
def __init__(self, config):
super(XLNetFeedForward, self).__init__()
......@@ -405,8 +430,9 @@ class XLNetFeedForward(nn.Module):
self.layer_1 = nn.Linear(config.d_model, config.d_inner)
self.layer_2 = nn.Linear(config.d_inner, config.d_model)
self.dropout = nn.Dropout(config.dropout)
if isinstance(config.ff_activation, str) or \
(sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)):
if isinstance(config.ff_activation, str) or (
sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)
):
self.activation_function = ACT2FN[config.ff_activation]
else:
self.activation_function = config.ff_activation
......@@ -421,6 +447,7 @@ class XLNetFeedForward(nn.Module):
output = self.layer_norm(output + inp)
return output
class XLNetLayer(nn.Module):
def __init__(self, config):
super(XLNetLayer, self).__init__()
......@@ -428,12 +455,20 @@ class XLNetLayer(nn.Module):
self.ff = XLNetFeedForward(config)
self.dropout = nn.Dropout(config.dropout)
def forward(self, output_h, output_g,
attn_mask_h, attn_mask_g,
r, seg_mat, mems=None, target_mapping=None, head_mask=None):
outputs = self.rel_attn(output_h, output_g, attn_mask_h, attn_mask_g,
r, seg_mat, mems=mems, target_mapping=target_mapping,
head_mask=head_mask)
def forward(
self, output_h, output_g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None
):
outputs = self.rel_attn(
output_h,
output_g,
attn_mask_h,
attn_mask_g,
r,
seg_mat,
mems=mems,
target_mapping=target_mapping,
head_mask=head_mask,
)
output_h, output_g = outputs[:2]
if output_g is not None:
......@@ -448,6 +483,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = XLNetConfig
pretrained_model_archive_map = XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_xlnet
......@@ -466,12 +502,20 @@ class XLNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, XLNetRelativeAttention):
for param in [module.q, module.k, module.v, module.o, module.r,
module.r_r_bias, module.r_s_bias, module.r_w_bias,
module.seg_embed]:
for param in [
module.q,
module.k,
module.v,
module.o,
module.r,
module.r_r_bias,
module.r_s_bias,
module.r_w_bias,
module.seg_embed,
]:
param.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, XLNetModel):
module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range)
module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range)
XLNET_START_DOCSTRING = r""" The XLNet model was proposed in
......@@ -564,8 +608,12 @@ XLNET_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
@add_start_docstrings(
"The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING,
)
class XLNetModel(XLNetPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
......@@ -594,6 +642,7 @@ class XLNetModel(XLNetPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(XLNetModel, self).__init__(config)
self.output_attentions = config.output_attentions
......@@ -658,18 +707,18 @@ class XLNetModel(XLNetPreTrainedModel):
def cache_mem(self, curr_out, prev_mem):
"""cache hidden states into memory."""
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len]
curr_out = curr_out[: self.reuse_len]
if prev_mem is None:
new_mem = curr_out[-self.mem_len:]
new_mem = curr_out[-self.mem_len :]
else:
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len :]
return new_mem.detach()
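# A minimal sketch of the memory window computed by cache_mem, assuming
# mem_len=4, reuse_len=None, and the (len, bsz, d_model) layout used here.
import torch

mem_len = 4
prev_mem = torch.arange(3.0).view(3, 1, 1)  # 3 previously cached steps
curr_out = torch.arange(3.0, 8.0).view(5, 1, 1)  # 5 new hidden states
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-mem_len:]
assert new_mem.flatten().tolist() == [4.0, 5.0, 6.0, 7.0]  # last mem_len steps survive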
@staticmethod
def positional_embedding(pos_seq, inv_freq, bsz=None):
sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
sinusoid_inp = torch.einsum("i,d->id", pos_seq, inv_freq)
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
pos_emb = pos_emb[:, None, :]
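# A minimal sketch of positional_embedding for a toy d_model=4: inv_freq holds
# two frequencies and the result concatenates [sin | cos] along the last axis.
import torch

d_model = 4
freq_seq = torch.arange(0, d_model, 2.0)
inv_freq = 1 / torch.pow(10000, freq_seq / d_model)
pos_seq = torch.arange(3.0)  # three positions
sinusoid_inp = torch.einsum("i,d->id", pos_seq, inv_freq)  # outer product, (3, 2)
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)[:, None, :]
assert pos_emb.shape == (3, 1, d_model)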
......@@ -683,14 +732,14 @@ class XLNetModel(XLNetPreTrainedModel):
freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))
if self.attn_type == 'bi':
if self.attn_type == "bi":
# beg, end = klen - 1, -qlen
beg, end = klen, -qlen
elif self.attn_type == 'uni':
elif self.attn_type == "uni":
# beg, end = klen - 1, -1
beg, end = klen, -1
else:
raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))
raise ValueError("Unknown `attn_type` {}.".format(self.attn_type))
if self.bi_data:
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
......@@ -701,8 +750,8 @@ class XLNetModel(XLNetPreTrainedModel):
bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
if bsz is not None:
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
else:
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)
......@@ -717,8 +766,18 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb = pos_emb.to(next(self.parameters()))
return pos_emb
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None):
def forward(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
):
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
......@@ -739,7 +798,6 @@ class XLNetModel(XLNetPreTrainedModel):
perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
klen = mlen + qlen
......@@ -748,13 +806,13 @@ class XLNetModel(XLNetPreTrainedModel):
##### Attention mask
# causal attention mask
if self.attn_type == 'uni':
if self.attn_type == "uni":
attn_mask = self.create_mask(qlen, mlen)
attn_mask = attn_mask[:, :, None, None]
elif self.attn_type == 'bi':
elif self.attn_type == "bi":
attn_mask = None
else:
raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
raise ValueError("Unsupported attention type: {}".format(self.attn_type))
# data mask: input mask & perm mask
assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) "
......@@ -799,9 +857,9 @@ class XLNetModel(XLNetPreTrainedModel):
output_h = self.dropout(word_emb_k)
if target_mapping is not None:
word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
# else: # We removed the inp_q input which was same as target mapping
# inp_q_ext = inp_q[:, :, None]
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
# else: # We removed the inp_q input which was same as target mapping
# inp_q_ext = inp_q[:, :, None]
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_g = self.dropout(word_emb_q)
else:
output_g = None
......@@ -836,7 +894,9 @@ class XLNetModel(XLNetPreTrainedModel):
head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
head_mask = head_mask.to(
dtype=next(self.parameters()).dtype
) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.n_layer
......@@ -853,9 +913,17 @@ class XLNetModel(XLNetPreTrainedModel):
if self.output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
outputs = layer_module(output_h, output_g, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
r=pos_emb, seg_mat=seg_mat, mems=mems[i], target_mapping=target_mapping,
head_mask=head_mask[i])
outputs = layer_module(
output_h,
output_g,
attn_mask_h=non_tgt_mask,
attn_mask_g=attn_mask,
r=pos_emb,
seg_mat=seg_mat,
mems=mems[i],
target_mapping=target_mapping,
head_mask=head_mask[i],
)
output_h, output_g = outputs[:2]
if self.output_attentions:
attentions.append(outputs[2])
......@@ -881,7 +949,9 @@ class XLNetModel(XLNetPreTrainedModel):
if self.output_attentions:
if target_mapping is not None:
# when target_mapping is provided, there are 2-tuple of attentions
attentions = tuple(tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions)
attentions = tuple(
tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions
)
else:
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs = outputs + (attentions,)
......@@ -889,9 +959,12 @@ class XLNetModel(XLNetPreTrainedModel):
return outputs # outputs, (new_mems), (hidden_states), (attentions)
@add_start_docstrings("""XLNet Model with a language modeling head on top
@add_start_docstrings(
"""XLNet Model with a language modeling head on top
(linear layer with weights tied to the input embeddings). """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING,
)
class XLNetLMHeadModel(XLNetPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
......@@ -934,6 +1007,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
"""
def __init__(self, config):
super(XLNetLMHeadModel, self).__init__(config)
self.attn_type = config.attn_type
......@@ -954,34 +1028,42 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
# Build permutation mask so that previous tokens don't see last token
perm_mask = torch.zeros(
(input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]),
dtype=torch.float, device=input_ids.device
(input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=input_ids.device
)
perm_mask[:, :, -1] = 1.0
# We'll only predict the last token
target_mapping = torch.zeros(
(input_ids.shape[0], 1, input_ids.shape[1]),
dtype=torch.float, device=input_ids.device
(input_ids.shape[0], 1, input_ids.shape[1]), dtype=torch.float, device=input_ids.device
)
target_mapping[0, 0, -1] = 1.0
return {"input_ids": input_ids,
"perm_mask": perm_mask,
"target_mapping": target_mapping
}
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
return {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
def forward(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
logits = self.lm_loss(transformer_outputs[0])
......@@ -990,16 +1072,18 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
if labels is not None:
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)),
labels.view(-1))
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
outputs = (loss,) + outputs
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
@add_start_docstrings(
"""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING,
)
class XLNetForSequenceClassification(XLNetPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -1037,6 +1121,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
loss, logits = outputs[:2]
"""
def __init__(self, config):
super(XLNetForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
......@@ -1047,17 +1132,30 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
output = transformer_outputs[0]
output = self.sequence_summary(output)
......@@ -1077,10 +1175,13 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a token classification head on top (a linear layer on top of
@add_start_docstrings(
"""XLNet Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING)
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING,
)
class XLNetForTokenClassification(XLNetPreTrainedModel):
r"""
Inputs:
......@@ -1135,6 +1236,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
scores = outputs[0]
"""
def __init__(self, config):
super(XLNetForTokenClassification, self).__init__(config)
self.num_labels = config.num_labels
......@@ -1144,18 +1246,31 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
outputs = self.transformer(input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
sequence_output = outputs[0]
......@@ -1177,9 +1292,12 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
@add_start_docstrings(
"""XLNet Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING,
)
class XLNetForMultipleChoice(XLNetPreTrainedModel):
r"""
Inputs:
......@@ -1239,6 +1357,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
loss, classification_scores = outputs[:2]
"""
def __init__(self, config):
super(XLNetForMultipleChoice, self).__init__(config)
......@@ -1248,9 +1367,19 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None,
labels=None, head_mask=None, inputs_embeds=None):
def forward(
self,
input_ids=None,
token_type_ids=None,
input_mask=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
labels=None,
head_mask=None,
inputs_embeds=None,
):
num_choices = input_ids.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
......@@ -1258,18 +1387,26 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None
transformer_outputs = self.transformer(flat_input_ids, token_type_ids=flat_token_type_ids,
input_mask=flat_input_mask, attention_mask=flat_attention_mask,
mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
head_mask=head_mask, inputs_embeds=inputs_embeds)
transformer_outputs = self.transformer(
flat_input_ids,
token_type_ids=flat_token_type_ids,
input_mask=flat_input_mask,
attention_mask=flat_attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
output = transformer_outputs[0]
output = self.sequence_summary(output)
logits = self.logits_proj(output)
reshaped_logits = logits.view(-1, num_choices)
outputs = (reshaped_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
outputs = (reshaped_logits,) + transformer_outputs[
1:
] # Keep mems, hidden states, attentions if they are present
if labels is not None:
loss_fct = CrossEntropyLoss()
......@@ -1279,9 +1416,12 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
@add_start_docstrings(
"""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING,
)
class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -1325,6 +1465,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
loss, start_scores, end_scores = outputs[:2]
"""
def __init__(self, config):
super(XLNetForQuestionAnsweringSimple, self).__init__(config)
self.num_labels = config.num_labels
......@@ -1334,19 +1475,32 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None):
outputs = self.transformer(input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
):
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
sequence_output = outputs[0]
......@@ -1376,9 +1530,12 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
return outputs # (loss), start_logits, end_logits, (mems), (hidden_states), (attentions)
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
@add_start_docstrings(
"""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING,
)
class XLNetForQuestionAnswering(XLNetPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......@@ -1440,6 +1597,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
loss, start_scores, end_scores = outputs[:2]
"""
def __init__(self, config):
super(XLNetForQuestionAnswering, self).__init__(config)
self.start_n_top = config.start_n_top
......@@ -1452,18 +1610,34 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None, is_impossible=None, cls_index=None, p_mask=None,):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
def forward(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
is_impossible=None,
cls_index=None,
p_mask=None,
):
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
hidden_states = transformer_outputs[0]
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
......@@ -1497,24 +1671,34 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
else:
# during inference, compute the end logits based on beam search
bsz, slen, hsz = hidden_states.size()
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_top_log_probs, start_top_index = torch.topk(
start_log_probs, self.start_n_top, dim=-1
) # shape (bsz, start_n_top)
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
start_states
) # shape (bsz, slen, start_n_top, hsz)
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
end_top_log_probs, end_top_index = torch.topk(
end_log_probs, self.end_n_top, dim=1
) # shape (bsz, end_n_top, start_n_top)
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) # get the representation of START as weighted sum of hidden states
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) # Shape (batch size,): one single `cls_logits` for each sample
start_states = torch.einsum(
"blh,bl->bh", hidden_states, start_log_probs
) # get the representation of START as weighted sum of hidden states
cls_logits = self.answer_class(
hidden_states, start_states=start_states, cls_index=cls_index
) # Shape (batch size,): one single `cls_logits` for each sample
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
......
......@@ -34,10 +34,11 @@ def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1
""" Create a schedule with a constant learning rate preceded by a warmup
period during which the learning rate increases linearly from 0 to the initial lr set in the optimizer.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1.0, num_warmup_steps))
return 1.
return 1.0
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
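# A minimal sketch of how these schedules are wired up, using this module's
# AdamW; the scheduler is stepped once per optimizer step.
import torch

model = torch.nn.Linear(4, 2)
optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=10)
for _ in range(3):
    optimizer.step()  # gradients omitted in this sketch
    scheduler.step()  # lr ramps linearly toward 5e-5 over the first 10 steps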
......@@ -46,40 +47,47 @@ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st
""" Create a schedule with a learning rate that decreases linearly after
linearly increasing during a warmup period.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
return max(
0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
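# A minimal sketch of the multiplier lr_lambda produces, assuming
# num_warmup_steps=10 and num_training_steps=100.
def linear_lambda(step, warmup=10, total=100):
    if step < warmup:
        return float(step) / float(max(1, warmup))
    return max(0.0, float(total - step) / float(max(1, total - warmup)))

assert linear_lambda(5) == 0.5  # halfway through warmup
assert linear_lambda(10) == 1.0  # warmup finished
assert linear_lambda(55) == 0.5  # halfway through the linear decay
assert linear_lambda(100) == 0.0  # fully decayed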
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1):
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
""" Create a schedule with a learning rate that decreases following the
values of the cosine function between 0 and `pi * cycles` after a warmup
period during which it increases linearly from 0 to the initial lr set in the optimizer.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress)))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1):
def get_cosine_with_hard_restarts_schedule_with_warmup(
optimizer, num_warmup_steps, num_training_steps, num_cycles=1.0, last_epoch=-1
):
""" Create a schedule with a learning rate that decreases following the
values of the cosine function with several hard restarts, after a warmup
period during which it increases linearly from 0 to the initial lr set in the optimizer.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
if progress >= 1.:
return 0.
return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.))))
if progress >= 1.0:
return 0.0
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
return LambdaLR(optimizer, lr_lambda, last_epoch)
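# A minimal sketch contrasting the two cosine variants above: the smooth
# schedule keeps decaying, while hard restarts jump back to the full rate at
# each cycle boundary. The num_cycles values here are illustrative.
import math

def cosine_mult(progress, num_cycles=0.5):
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

def restart_mult(progress, num_cycles=2.0):
    if progress >= 1.0:
        return 0.0
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))

assert abs(cosine_mult(0.75) - 0.1464) < 1e-3  # smooth decay keeps falling
assert restart_mult(0.5) == 1.0  # the second cycle restarts at the full rate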
......@@ -94,17 +102,17 @@ class AdamW(Optimizer):
weight_decay (float): Weight decay. Default: 0.0
correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. as in the Bert TF repository). Default: True.
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
if lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
correct_bias=correct_bias)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
super(AdamW, self).__init__(params, defaults)
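# A minimal sketch of the usual parameter grouping with this AdamW: weight
# decay on weights, none on biases, and correct_bias=False to match the BERT
# TF training setup.
import torch

model = torch.nn.Linear(4, 2)
no_decay = ["bias", "LayerNorm.weight"]
grouped_params = [
    {"params": [p for name, p in model.named_parameters() if not any(nd in name for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for name, p in model.named_parameters() if any(nd in name for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = AdamW(grouped_params, lr=2e-5, correct_bias=False)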
def step(self, closure=None):
......@@ -119,38 +127,38 @@ class AdamW(Optimizer):
loss = closure()
for group in self.param_groups:
for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
state["step"] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
state["exp_avg_sq"] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state['step'] += 1
state["step"] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
denom = exp_avg_sq.sqrt().add_(group['eps'])
denom = exp_avg_sq.sqrt().add_(group["eps"])
step_size = group['lr']
if group['correct_bias']: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state['step']
bias_correction2 = 1.0 - beta2 ** state['step']
step_size = group["lr"]
if group["correct_bias"]: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
......@@ -163,7 +171,7 @@ class AdamW(Optimizer):
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group['weight_decay'] > 0.0:
p.data.add_(-group['lr'] * group['weight_decay'], p.data)
if group["weight_decay"] > 0.0:
p.data.add_(-group["lr"] * group["weight_decay"], p.data)
return loss
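# A minimal sketch of the decoupled decay applied above: the decay multiplies
# the weights directly, outside the Adam moment estimates, rather than entering
# the gradient as an L2 term. The update value here is a stand-in.
import torch

lr, weight_decay = 0.1, 0.01
p = torch.ones(3)
adam_update = torch.full((3,), 0.05)  # stand-in for step_size * exp_avg / denom
p = p - adam_update  # the Adam step
p = p - lr * weight_decay * p  # decoupled decay, as in p.data.add_(-lr * wd, p.data)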
......@@ -24,70 +24,64 @@ import tensorflow as tf
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applys a warmup schedule on a given learning rate decay schedule."""
def __init__(
self,
initial_learning_rate,
decay_schedule_fn,
warmup_steps,
power=1.0,
name=None):
super(WarmUp, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.warmup_steps = warmup_steps
self.power = power
self.decay_schedule_fn = decay_schedule_fn
self.name = name
def __call__(self, step):
with tf.name_scope(self.name or 'WarmUp') as name:
# Implements polynomial warmup. i.e., if global_step < warmup_steps, the
# learning rate will be `global_step/num_warmup_steps * init_lr`.
global_step_float = tf.cast(step, tf.float32)
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
warmup_percent_done = global_step_float / warmup_steps_float
warmup_learning_rate = (
self.initial_learning_rate *
tf.math.pow(warmup_percent_done, self.power))
return tf.cond(global_step_float < warmup_steps_float,
lambda: warmup_learning_rate,
lambda: self.decay_schedule_fn(step),
name=name)
def get_config(self):
return {
'initial_learning_rate': self.initial_learning_rate,
'decay_schedule_fn': self.decay_schedule_fn,
'warmup_steps': self.warmup_steps,
'power': self.power,
'name': self.name
}
"""Applys a warmup schedule on a given learning rate decay schedule."""
def __init__(self, initial_learning_rate, decay_schedule_fn, warmup_steps, power=1.0, name=None):
super(WarmUp, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.warmup_steps = warmup_steps
self.power = power
self.decay_schedule_fn = decay_schedule_fn
self.name = name
def __call__(self, step):
with tf.name_scope(self.name or "WarmUp") as name:
# Implements polynomial warmup. i.e., if global_step < warmup_steps, the
# learning rate will be `global_step/num_warmup_steps * init_lr`.
global_step_float = tf.cast(step, tf.float32)
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
warmup_percent_done = global_step_float / warmup_steps_float
warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power)
return tf.cond(
global_step_float < warmup_steps_float,
lambda: warmup_learning_rate,
lambda: self.decay_schedule_fn(step),
name=name,
)
def get_config(self):
return {
"initial_learning_rate": self.initial_learning_rate,
"decay_schedule_fn": self.decay_schedule_fn,
"warmup_steps": self.warmup_steps,
"power": self.power,
"name": self.name,
}
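# A minimal sketch of WarmUp with power=1.0: a linear ramp from 0 to the
# initial learning rate over warmup_steps, after which the wrapped decay
# schedule takes over.
import tensorflow as tf

decay = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1e-3, decay_steps=100, end_learning_rate=0.0
)
warmup = WarmUp(initial_learning_rate=1e-3, decay_schedule_fn=decay, warmup_steps=10)
assert abs(float(warmup(5)) - 5e-4) < 1e-9  # halfway through warmup: half the lr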
def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
"""Creates an optimizer with learning rate schedule."""
# Implements linear decay of the learning rate.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=init_lr,
decay_steps=num_train_steps,
end_learning_rate=0.0)
if num_warmup_steps:
learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
decay_schedule_fn=learning_rate_fn,
warmup_steps=num_warmup_steps)
optimizer = AdamWeightDecay(
learning_rate=learning_rate_fn,
weight_decay_rate=0.01,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-6,
exclude_from_weight_decay=['layer_norm', 'bias'])
return optimizer
"""Creates an optimizer with learning rate schedule."""
# Implements linear decay of the learning rate.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=init_lr, decay_steps=num_train_steps, end_learning_rate=0.0
)
if num_warmup_steps:
learning_rate_fn = WarmUp(
initial_learning_rate=init_lr, decay_schedule_fn=learning_rate_fn, warmup_steps=num_warmup_steps
)
optimizer = AdamWeightDecay(
learning_rate=learning_rate_fn,
weight_decay_rate=0.01,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-6,
exclude_from_weight_decay=["layer_norm", "bias"],
)
return optimizer
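# A minimal sketch of calling create_optimizer; the 10% warmup fraction is an
# assumption for illustration, not something this function prescribes.
num_train_steps = 1000
optimizer = create_optimizer(
    init_lr=3e-5, num_train_steps=num_train_steps, num_warmup_steps=num_train_steps // 10
)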
class AdamWeightDecay(tf.keras.optimizers.Adam):
"""Adam enables L2 weight decay and clip_by_global_norm on gradients.
"""Adam enables L2 weight decay and clip_by_global_norm on gradients.
Just adding the square of the weights to the loss function is *not* the
correct way of using L2 regularization/weight decay with Adam, since that will
......@@ -98,99 +92,94 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
the loss with plain (non-momentum) SGD.
"""
def __init__(self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
amsgrad=False,
weight_decay_rate=0.0,
include_in_weight_decay=None,
exclude_from_weight_decay=None,
name='AdamWeightDecay',
**kwargs):
super(AdamWeightDecay, self).__init__(
learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
self.weight_decay_rate = weight_decay_rate
self._include_in_weight_decay = include_in_weight_decay
self._exclude_from_weight_decay = exclude_from_weight_decay
@classmethod
def from_config(cls, config):
"""Creates an optimizer from its config with WarmUp custom object."""
custom_objects = {'WarmUp': WarmUp}
return super(AdamWeightDecay, cls).from_config(
config, custom_objects=custom_objects)
def _prepare_local(self, var_device, var_dtype, apply_state):
super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype,
apply_state)
apply_state['weight_decay_rate'] = tf.constant(
self.weight_decay_rate, name='adam_weight_decay_rate')
def _decay_weights_op(self, var, learning_rate, apply_state):
do_decay = self._do_use_weight_decay(var.name)
if do_decay:
return var.assign_sub(
learning_rate * var *
apply_state['weight_decay_rate'],
use_locking=self._use_locking)
return tf.no_op()
def apply_gradients(self, grads_and_vars, clip_norm, name=None):
grads, tvars = list(zip(*grads_and_vars))
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))
def _get_lr(self, var_device, var_dtype, apply_state):
"""Retrieves the learning rate with the given state."""
if apply_state is None:
return self._decayed_lr_t[var_dtype], {}
apply_state = apply_state or {}
coefficients = apply_state.get((var_device, var_dtype))
if coefficients is None:
coefficients = self._fallback_apply_state(var_device, var_dtype)
apply_state[(var_device, var_dtype)] = coefficients
return coefficients['lr_t'], dict(apply_state=apply_state)
def _resource_apply_dense(self, grad, var, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay, self)._resource_apply_dense(
grad, var, **kwargs)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay, self)._resource_apply_sparse(
grad, var, indices, **kwargs)
def get_config(self):
config = super(AdamWeightDecay, self).get_config()
config.update({
'weight_decay_rate': self.weight_decay_rate,
})
return config
def _do_use_weight_decay(self, param_name):
"""Whether to use L2 weight decay for `param_name`."""
if self.weight_decay_rate == 0:
return False
if self._include_in_weight_decay:
for r in self._include_in_weight_decay:
if re.search(r, param_name) is not None:
return True
if self._exclude_from_weight_decay:
for r in self._exclude_from_weight_decay:
if re.search(r, param_name) is not None:
return False
return True
def __init__(
self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
amsgrad=False,
weight_decay_rate=0.0,
include_in_weight_decay=None,
exclude_from_weight_decay=None,
name="AdamWeightDecay",
**kwargs
):
super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
self.weight_decay_rate = weight_decay_rate
self._include_in_weight_decay = include_in_weight_decay
self._exclude_from_weight_decay = exclude_from_weight_decay
@classmethod
def from_config(cls, config):
"""Creates an optimizer from its config with WarmUp custom object."""
custom_objects = {"WarmUp": WarmUp}
return super(AdamWeightDecay, cls).from_config(config, custom_objects=custom_objects)
def _prepare_local(self, var_device, var_dtype, apply_state):
super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, apply_state)
apply_state["weight_decay_rate"] = tf.constant(self.weight_decay_rate, name="adam_weight_decay_rate")
def _decay_weights_op(self, var, learning_rate, apply_state):
do_decay = self._do_use_weight_decay(var.name)
if do_decay:
return var.assign_sub(
learning_rate * var * apply_state["weight_decay_rate"], use_locking=self._use_locking
)
return tf.no_op()
def apply_gradients(self, grads_and_vars, clip_norm, name=None):
grads, tvars = list(zip(*grads_and_vars))
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))
def _get_lr(self, var_device, var_dtype, apply_state):
"""Retrieves the learning rate with the given state."""
if apply_state is None:
return self._decayed_lr_t[var_dtype], {}
apply_state = apply_state or {}
coefficients = apply_state.get((var_device, var_dtype))
if coefficients is None:
coefficients = self._fallback_apply_state(var_device, var_dtype)
apply_state[(var_device, var_dtype)] = coefficients
return coefficients["lr_t"], dict(apply_state=apply_state)
def _resource_apply_dense(self, grad, var, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay, self)._resource_apply_dense(grad, var, **kwargs)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay, self)._resource_apply_sparse(grad, var, indices, **kwargs)
def get_config(self):
config = super(AdamWeightDecay, self).get_config()
config.update(
{"weight_decay_rate": self.weight_decay_rate,}
)
return config
def _do_use_weight_decay(self, param_name):
"""Whether to use L2 weight decay for `param_name`."""
if self.weight_decay_rate == 0:
return False
if self._include_in_weight_decay:
for r in self._include_in_weight_decay:
if re.search(r, param_name) is not None:
return True
if self._exclude_from_weight_decay:
for r in self._exclude_from_weight_decay:
if re.search(r, param_name) is not None:
return False
return True
## Inspired by https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
......@@ -201,10 +190,8 @@ class GradientAccumulator(object):
"""Initializes the accumulator."""
self._gradients = []
self._accum_steps = tf.Variable(
initial_value=0,
dtype=tf.int64,
trainable=False,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
initial_value=0, dtype=tf.int64, trainable=False, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA
)
@property
def step(self):
......@@ -214,12 +201,19 @@ class GradientAccumulator(object):
@property
def gradients(self):
"""The accumulated gradients."""
return list(gradient.value() if gradient is not None else gradient for gradient in self._get_replica_gradients())
return list(
gradient.value() if gradient is not None else gradient for gradient in self._get_replica_gradients()
)
def __call__(self, gradients):
"""Accumulates :obj:`gradients`."""
if not self._gradients:
self._gradients.extend([tf.Variable(tf.zeros_like(gradient), trainable=False) if gradient is not None else gradient for gradient in gradients])
self._gradients.extend(
[
tf.Variable(tf.zeros_like(gradient), trainable=False) if gradient is not None else gradient
for gradient in gradients
]
)
if len(gradients) != len(self._gradients):
raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients)))
......@@ -249,6 +243,9 @@ class GradientAccumulator(object):
if replica_context is None or tf.distribute.get_strategy().num_replicas_in_sync == 1:
return self._gradients
return (gradient.device_map.select_for_current_replica(gradient.values, replica_context) for gradient in self._gradients)
return (
gradient.device_map.select_for_current_replica(gradient.values, replica_context)
for gradient in self._gradients
)
else:
return self._gradients
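# A minimal sketch of accumulating gradients over micro-batches before a single
# optimizer step. `model`, `loss_fn`, `optimizer`, and `micro_batches` are
# assumed names, as is the accumulator's reset() method (not shown in this hunk).
import tensorflow as tf

accumulator = GradientAccumulator()
for features, labels in micro_batches:  # assumed iterable of (x, y) pairs
    with tf.GradientTape() as tape:
        loss = loss_fn(model(features), labels)
    accumulator(tape.gradient(loss, model.trainable_variables))
optimizer.apply_gradients(zip(accumulator.gradients, model.trainable_variables))
accumulator.reset()  # assumed: clears the buffers for the next accumulation window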
......@@ -30,25 +30,42 @@ from typing import Union, Optional, Tuple, List, Dict
import numpy as np
from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
PretrainedConfig, ModelCard, SquadExample,
squad_convert_examples_to_features, is_tf_available,
is_torch_available, BasicTokenizer,
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP)
from transformers import (
AutoConfig,
AutoTokenizer,
PreTrainedTokenizer,
PretrainedConfig,
ModelCard,
SquadExample,
squad_convert_examples_to_features,
is_tf_available,
is_torch_available,
BasicTokenizer,
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
)
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModel, TFAutoModelForSequenceClassification, \
TFAutoModelForQuestionAnswering, TFAutoModelForTokenClassification
from transformers import (
TFAutoModel,
TFAutoModelForSequenceClassification,
TFAutoModelForQuestionAnswering,
TFAutoModelForTokenClassification,
)
if is_torch_available():
import torch
from transformers import AutoModel, AutoModelForSequenceClassification, \
AutoModelForQuestionAnswering, AutoModelForTokenClassification
from transformers import (
AutoModel,
AutoModelForSequenceClassification,
AutoModelForQuestionAnswering,
AutoModelForTokenClassification,
)
logger = logging.getLogger(__name__)
def get_framework(model=None):
""" Select framework (TensorFlow/PyTorch) to use.
If both frameworks are installed and no specific model is provided, defaults to using PyTorch.
......@@ -56,20 +73,24 @@ def get_framework(model=None):
if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str):
# Both frameworks are available but the user supplied a model class instance.
# Try to guess which framework to use from the model classname
framework = 'tf' if model.__class__.__name__.startswith('TF') else 'pt'
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
elif not is_tf_available() and not is_torch_available():
raise ImportError("At least one of TensorFlow 2.0 or PyTorch should be installed. "
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
"To install PyTorch, read the instructions at https://pytorch.org/.")
raise ImportError(
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
"To install PyTorch, read the instructions at https://pytorch.org/."
)
else:
# framework = 'tf' if is_tf_available() else 'pt'
framework = 'pt' if is_torch_available() else 'tf'
framework = "pt" if is_torch_available() else "tf"
return framework
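# A minimal sketch of the class-name heuristic above: TF model classes in this
# library are prefixed with "TF", so the prefix decides the framework.
class TFToyModel:  # stand-in for a real TF model class
    pass

framework = "tf" if TFToyModel().__class__.__name__.startswith("TF") else "pt"
assert framework == "tf"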
class ArgumentHandler(ABC):
"""
Base interface for handling varargs for each Pipeline
"""
@abstractmethod
def __call__(self, *args, **kwargs):
raise NotImplementedError()
......@@ -79,11 +100,12 @@ class DefaultArgumentHandler(ArgumentHandler):
"""
Default varargs argument parser handling parameters for each Pipeline
"""
def __call__(self, *args, **kwargs):
if 'X' in kwargs:
return kwargs['X']
elif 'data' in kwargs:
return kwargs['data']
if "X" in kwargs:
return kwargs["X"]
elif "data" in kwargs:
return kwargs["data"]
elif len(args) == 1:
if isinstance(args[0], list):
return args[0]
......@@ -91,7 +113,7 @@ class DefaultArgumentHandler(ArgumentHandler):
return [args[0]]
elif len(args) > 1:
return list(args)
raise ValueError('Unable to infer the format of the provided data (X=, data=, ...)')
raise ValueError("Unable to infer the format of the provided data (X=, data=, ...)")
class PipelineDataFormat:
......@@ -105,24 +127,25 @@ class PipelineDataFormat:
PipelineDataFormat also includes some utilities for working with multiple columns, such as mapping dataset columns
to pipeline keyword arguments through the `dataset_kwarg_1=dataset_column_1` format.
"""
SUPPORTED_FORMATS = ['json', 'csv', 'pipe']
SUPPORTED_FORMATS = ["json", "csv", "pipe"]
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
self.output_path = output_path
self.input_path = input_path
self.column = column.split(',') if column is not None else ['']
self.column = column.split(",") if column is not None else [""]
self.is_multi_columns = len(self.column) > 1
if self.is_multi_columns:
self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]
self.column = [tuple(c.split("=")) if "=" in c else (c, c) for c in self.column]
if output_path is not None and not overwrite:
if exists(abspath(self.output_path)):
raise OSError('{} already exists on disk'.format(self.output_path))
raise OSError("{} already exists on disk".format(self.output_path))
if input_path is not None:
if not exists(abspath(self.input_path)):
raise OSError('{} doesnt exist on disk'.format(self.input_path))
raise OSError("{} doesnt exist on disk".format(self.input_path))
@abstractmethod
def __iter__(self):
......@@ -144,23 +167,25 @@ class PipelineDataFormat:
:return: (str) Path where the data has been saved
"""
path, _ = os.path.splitext(self.output_path)
binary_path = os.path.extsep.join((path, 'pickle'))
binary_path = os.path.extsep.join((path, "pickle"))
with open(binary_path, 'wb+') as f_output:
with open(binary_path, "wb+") as f_output:
pickle.dump(data, f_output)
return binary_path
@staticmethod
def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
if format == 'json':
def from_str(
format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False
):
if format == "json":
return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
elif format == 'csv':
elif format == "csv":
return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
elif format == 'pipe':
elif format == "pipe":
return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
else:
raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(format))
raise KeyError("Unknown reader {} (Available reader are json/csv/pipe)".format(format))
class CsvPipelineDataFormat(PipelineDataFormat):
......@@ -168,7 +193,7 @@ class CsvPipelineDataFormat(PipelineDataFormat):
super().__init__(output_path, input_path, column, overwrite=overwrite)
def __iter__(self):
with open(self.input_path, 'r') as f:
with open(self.input_path, "r") as f:
reader = csv.DictReader(f)
for row in reader:
if self.is_multi_columns:
......@@ -177,7 +202,7 @@ class CsvPipelineDataFormat(PipelineDataFormat):
yield row[self.column[0]]
def save(self, data: List[dict]):
with open(self.output_path, 'w') as f:
with open(self.output_path, "w") as f:
if len(data) > 0:
writer = csv.DictWriter(f, list(data[0].keys()))
writer.writeheader()
......@@ -188,7 +213,7 @@ class JsonPipelineDataFormat(PipelineDataFormat):
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
super().__init__(output_path, input_path, column, overwrite=overwrite)
with open(input_path, 'r') as f:
with open(input_path, "r") as f:
self._entries = json.load(f)
def __iter__(self):
......@@ -199,7 +224,7 @@ class JsonPipelineDataFormat(PipelineDataFormat):
yield entry[self.column[0]]
def save(self, data: dict):
with open(self.output_path, 'w') as f:
with open(self.output_path, "w") as f:
json.dump(data, f)
......@@ -210,12 +235,13 @@ class PipedPipelineDataFormat(PipelineDataFormat):
If columns are provided, then the output will be a dictionary with {column_x: value_x}
"""
def __iter__(self):
for line in sys.stdin:
# Split for multi-columns
if '\t' in line:
if "\t" in line:
line = line.split('\t')
line = line.split("\t")
if self.column:
# Dictionary to map arguments
yield {kwargs: l for (kwargs, _), l in zip(self.column, line)}
......@@ -232,8 +258,8 @@ class PipedPipelineDataFormat(PipelineDataFormat):
def save_binary(self, data: Union[dict, List[dict]]) -> str:
if self.output_path is None:
raise KeyError(
'Piped input requires an output file path when the pipeline returns large objects. '
'Please provide such an output path through the --output argument.'
"Piped input requires an output file path when the pipeline returns large objects. "
"Please provide such an output path through the --output argument."
)
return super().save_binary(data)
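PipedPipelineDataFormat splits each stdin line on tabs and, when columns are configured, zips the values into keyword arguments. A sketch of that mapping, assuming two configured columns and a made-up input line:

# Tab-splitting behavior of __iter__ above; the line and columns are illustrative.
line = "What is HF?\tHugging Face is a company.\n"
column = [("question", "question"), ("context", "context")]
values = line.split("\t")
record = {kwarg: value for (kwarg, _), value in zip(column, values)}
# {'question': 'What is HF?', 'context': 'Hugging Face is a company.\n'}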
......@@ -298,10 +324,16 @@ class Pipeline(_ScikitCompat):
default_input_names = None
def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None, framework: Optional[str] = None,
args_parser: ArgumentHandler = None, device: int = -1,
binary_output: bool = False):
def __init__(
self,
model,
tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None,
framework: Optional[str] = None,
args_parser: ArgumentHandler = None,
device: int = -1,
binary_output: bool = False,
):
if framework is None:
framework = get_framework()
......@@ -315,8 +347,8 @@ class Pipeline(_ScikitCompat):
self._args_parser = args_parser or DefaultArgumentHandler()
# Special handling
if self.device >= 0 and self.framework == 'pt':
self.model = self.model.to('cuda:{}'.format(self.device))
if self.device >= 0 and self.framework == "pt":
self.model = self.model.to("cuda:{}".format(self.device))
def save_pretrained(self, save_directory):
"""
......@@ -356,8 +388,8 @@ class Pipeline(_ScikitCompat):
Returns:
Context manager
"""
if self.framework == 'tf':
with tf.device('/CPU:0' if self.device == -1 else '/device:GPU:{}'.format(self.device)):
if self.framework == "tf":
with tf.device("/CPU:0" if self.device == -1 else "/device:GPU:{}".format(self.device)):
yield
else:
if self.device >= 0:
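The else branch (truncated by the hunk) handles PyTorch; under that assumption, a minimal sketch of CUDA device scoping with an illustrative index:

import torch

# Sketch of the PyTorch branch: with device >= 0, CUDA allocations inside the
# block land on that device. The index is illustrative.
device = 0
if torch.cuda.is_available():
    with torch.cuda.device(device):
        x = torch.zeros(1, device="cuda")  # placed on cuda:0 by the context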
......@@ -372,11 +404,11 @@ class Pipeline(_ScikitCompat):
Returns:
dict holding all the required parameters for the model's forward pass
"""
args = ['input_ids', 'attention_mask']
args = ["input_ids", "attention_mask"]
model_type = type(self.model).__name__.lower()
if 'distilbert' not in model_type and 'xlm' not in model_type:
args += ['token_type_ids']
if "distilbert" not in model_type and "xlm" not in model_type:
args += ["token_type_ids"]
# PR #1548 (CLI) There is an issue with attention_mask
# if 'xlnet' in model_type or 'xlm' in model_type:
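The name-based filter above keeps token_type_ids out of DistilBERT and XLM forward calls; a standalone sketch with an illustrative class name:

# Name-based feature filter from inputs_for_model; the class name is illustrative.
model_type = "DistilBertForSequenceClassification".lower()
args = ["input_ids", "attention_mask"]
if "distilbert" not in model_type and "xlm" not in model_type:
    args += ["token_type_ids"]
print(args)  # ['input_ids', 'attention_mask'] for a DistilBERT model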
......@@ -394,9 +426,7 @@ class Pipeline(_ScikitCompat):
# Encode for forward
with self.device_placement():
inputs = self.tokenizer.batch_encode_plus(
inputs, add_special_tokens=True,
return_tensors=self.framework,
max_length=self.tokenizer.max_len
inputs, add_special_tokens=True, return_tensors=self.framework, max_length=self.tokenizer.max_len
)
# Filter out features not available on specific models
......@@ -411,7 +441,7 @@ class Pipeline(_ScikitCompat):
Returns:
Numpy array
"""
if self.framework == 'tf':
if self.framework == "tf":
# TODO trace model
predictions = self.model(inputs, training=False)[0]
else:
......@@ -426,19 +456,24 @@ class FeatureExtractionPipeline(Pipeline):
Feature extraction pipeline using the base Model (no task-specific head).
"""
def __init__(self, model,
tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None,
framework: Optional[str] = None,
args_parser: ArgumentHandler = None,
device: int = -1):
super().__init__(model=model,
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
args_parser=args_parser,
device=device,
binary_output=True)
def __init__(
self,
model,
tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None,
framework: Optional[str] = None,
args_parser: ArgumentHandler = None,
device: int = -1,
):
super().__init__(
model=model,
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
args_parser=args_parser,
device=device,
binary_output=True,
)
def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs).tolist()
......@@ -452,7 +487,7 @@ class TextClassificationPipeline(Pipeline):
def __call__(self, *args, **kwargs):
outputs = super().__call__(*args, **kwargs)
scores = np.exp(outputs) / np.exp(outputs).sum(-1)
return [{'label': self.model.config.id2label[item.argmax()], 'score': item.max()} for item in scores]
return [{"label": self.model.config.id2label[item.argmax()], "score": item.max()} for item in scores]
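The score computation is a softmax over the label axis; a self-contained NumPy sketch with made-up logits (keepdims=True makes the broadcast explicit for batched 2-D outputs):

import numpy as np

# Softmax over the last axis; the logits are made up.
outputs = np.array([[1.0, 2.0], [0.5, -0.5]])
scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
print(scores.sum(-1))  # each row sums to 1.0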
class NerPipeline(Pipeline):
......@@ -460,19 +495,28 @@ class NerPipeline(Pipeline):
Named Entity Recognition pipeline using ModelForTokenClassification head.
"""
default_input_names = 'sequences'
def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None, framework: Optional[str] = None,
args_parser: ArgumentHandler = None, device: int = -1,
binary_output: bool = False, ignore_labels=['O']):
super().__init__(model=model,
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
args_parser=args_parser,
device=device,
binary_output=binary_output)
default_input_names = "sequences"
def __init__(
self,
model,
tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None,
framework: Optional[str] = None,
args_parser: ArgumentHandler = None,
device: int = -1,
binary_output: bool = False,
ignore_labels=["O"],
):
super().__init__(
model=model,
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
args_parser=args_parser,
device=device,
binary_output=binary_output,
)
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
self.ignore_labels = ignore_labels
......@@ -485,19 +529,20 @@ class NerPipeline(Pipeline):
with self.device_placement():
tokens = self.tokenizer.encode_plus(
sentence, return_attention_mask=False,
sentence,
return_attention_mask=False,
return_tensors=self.framework,
max_length=self.tokenizer.max_len
max_length=self.tokenizer.max_len,
)
# Forward
if self.framework == 'tf':
if self.framework == "tf":
entities = self.model(tokens)[0][0].numpy()
input_ids = tokens['input_ids'].numpy()[0]
input_ids = tokens["input_ids"].numpy()[0]
else:
with torch.no_grad():
entities = self.model(**tokens)[0][0].cpu().numpy()
input_ids = tokens['input_ids'].cpu().numpy()[0]
input_ids = tokens["input_ids"].cpu().numpy()[0]
score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
labels_idx = score.argmax(axis=-1)
......@@ -505,11 +550,13 @@ class NerPipeline(Pipeline):
answer = []
for idx, label_idx in enumerate(labels_idx):
if self.model.config.id2label[label_idx] not in self.ignore_labels:
answer += [{
'word': self.tokenizer.decode([int(input_ids[idx])]),
'score': score[idx][label_idx].item(),
'entity': self.model.config.id2label[label_idx]
}]
answer += [
{
"word": self.tokenizer.decode([int(input_ids[idx])]),
"score": score[idx][label_idx].item(),
"entity": self.model.config.id2label[label_idx],
}
]
# Append
answers += [answer]
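Each token whose label survives ignore_labels thus contributes a word/score/entity dict; an illustrative result for one sentence, with made-up values:

# Illustrative NerPipeline output for one sentence; values are made up.
answer = [
    {"word": "Hugging", "score": 0.998, "entity": "I-ORG"},
    {"word": "Face", "score": 0.997, "entity": "I-ORG"},
]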
......@@ -526,18 +573,19 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
QuestionAnsweringArgumentHandler manages all the possible ways to create a SquadExample from the command-line
supplied arguments.
"""
def __call__(self, *args, **kwargs):
# Positional args: handling is essentially the same as for X and data, so forward them to avoid duplication
if args is not None and len(args) > 0:
if len(args) == 1:
kwargs['X'] = args[0]
kwargs["X"] = args[0]
else:
kwargs['X'] = list(args)
kwargs["X"] = list(args)
# Generic compatibility with sklearn and Keras
# Batched data
if 'X' in kwargs or 'data' in kwargs:
inputs = kwargs['X'] if 'X' in kwargs else kwargs['data']
if "X" in kwargs or "data" in kwargs:
inputs = kwargs["X"] if "X" in kwargs else kwargs["data"]
if isinstance(inputs, dict):
inputs = [inputs]
......@@ -547,28 +595,31 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
for i, item in enumerate(inputs):
if isinstance(item, dict):
if any(k not in item for k in ['question', 'context']):
raise KeyError('You need to provide a dictionary with keys {question:..., context:...}')
if any(k not in item for k in ["question", "context"]):
raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")
inputs[i] = QuestionAnsweringPipeline.create_sample(**item)
elif not isinstance(item, SquadExample):
raise ValueError(
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
.format('X' if 'X' in kwargs else 'data')
"{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)".format(
"X" if "X" in kwargs else "data"
)
)
# Tabular input
elif 'question' in kwargs and 'context' in kwargs:
if isinstance(kwargs['question'], str):
kwargs['question'] = [kwargs['question']]
elif "question" in kwargs and "context" in kwargs:
if isinstance(kwargs["question"], str):
kwargs["question"] = [kwargs["question"]]
if isinstance(kwargs['context'], str):
kwargs['context'] = [kwargs['context']]
if isinstance(kwargs["context"], str):
kwargs["context"] = [kwargs["context"]]
inputs = [QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs['question'], kwargs['context'])]
inputs = [
QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs["question"], kwargs["context"])
]
else:
raise ValueError('Unknown arguments {}'.format(kwargs))
raise ValueError("Unknown arguments {}".format(kwargs))
if not isinstance(inputs, list):
inputs = [inputs]
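The handler therefore accepts positional dicts/SquadExamples, X=/data= batches, or parallel question/context values; a sketch of equivalent hypothetical calls through the pipeline factory defined later in this file:

# Hypothetical, equivalent invocations accepted by this handler; texts are made up.
nlp = pipeline("question-answering")
nlp({"question": "Who founded it?", "context": "Hugging Face was founded in 2016."})
nlp(X={"question": "Who founded it?", "context": "Hugging Face was founded in 2016."})
nlp(question="Who founded it?", context="Hugging Face was founded in 2016.")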
......@@ -581,22 +632,31 @@ class QuestionAnsweringPipeline(Pipeline):
Question Answering pipeline using ModelForQuestionAnswering head.
"""
default_input_names = 'question,context'
def __init__(self, model,
tokenizer: Optional[PreTrainedTokenizer],
modelcard: Optional[ModelCard],
framework: Optional[str] = None,
device: int = -1, **kwargs):
super().__init__(model=model,
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
args_parser=QuestionAnsweringArgumentHandler(),
device=device, **kwargs)
default_input_names = "question,context"
def __init__(
self,
model,
tokenizer: Optional[PreTrainedTokenizer],
modelcard: Optional[ModelCard],
framework: Optional[str] = None,
device: int = -1,
**kwargs
):
super().__init__(
model=model,
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
args_parser=QuestionAnsweringArgumentHandler(),
device=device,
**kwargs
)
@staticmethod
def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
def create_sample(
question: Union[str, List[str]], context: Union[str, List[str]]
) -> Union[SquadExample, List[SquadExample]]:
"""
QuestionAnsweringPipeline leverages the SquadExample/SquadFeatures internally.
This helper method encapsulates all the logic for converting question(s) and context(s) to SquadExample(s).
......@@ -629,26 +689,28 @@ class QuestionAnsweringPipeline(Pipeline):
end: the character index in the original string corresponding to the ending of the answer's span
"""
# Set default values
kwargs.setdefault('topk', 1)
kwargs.setdefault('doc_stride', 128)
kwargs.setdefault('max_answer_len', 15)
kwargs.setdefault('max_seq_len', 384)
kwargs.setdefault('max_question_len', 64)
kwargs.setdefault("topk", 1)
kwargs.setdefault("doc_stride", 128)
kwargs.setdefault("max_answer_len", 15)
kwargs.setdefault("max_seq_len", 384)
kwargs.setdefault("max_question_len", 64)
if kwargs['topk'] < 1:
raise ValueError('topk parameter should be >= 1 (got {})'.format(kwargs['topk']))
if kwargs["topk"] < 1:
raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))
if kwargs['max_answer_len'] < 1:
raise ValueError('max_answer_len parameter should be >= 1 (got {})'.format(kwargs['max_answer_len']))
if kwargs["max_answer_len"] < 1:
raise ValueError("max_answer_len parameter should be >= 1 (got {})".format(kwargs["max_answer_len"]))
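These decoding defaults can be overridden per call; a hedged example requesting the three best spans, with made-up texts:

# Hypothetical call overriding topk and max_answer_len; texts are made up.
nlp = pipeline("question-answering")
answers = nlp(
    question="Where is the company based?",
    context="Hugging Face is based in New York City.",
    topk=3,
    max_answer_len=10,
)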
# Convert inputs to features
examples = self._args_parser(*texts, **kwargs)
features = squad_convert_examples_to_features(examples, self.tokenizer, kwargs['max_seq_len'], kwargs['doc_stride'], kwargs['max_question_len'], False)
features = squad_convert_examples_to_features(
examples, self.tokenizer, kwargs["max_seq_len"], kwargs["doc_stride"], kwargs["max_question_len"], False
)
fw_args = self.inputs_for_model([f.__dict__ for f in features])
# Manage tensor allocation on correct device
with self.device_placement():
if self.framework == 'tf':
if self.framework == "tf":
fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
start, end = self.model(fw_args)
start, end = start.numpy(), end.numpy()
......@@ -672,16 +734,18 @@ class QuestionAnsweringPipeline(Pipeline):
# Mask CLS
start_[0] = end_[0] = 0
starts, ends, scores = self.decode(start_, end_, kwargs['topk'], kwargs['max_answer_len'])
starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
char_to_word = np.array(example.char_to_word_offset)
# Convert the answer (tokens) back to the original text
answers += [
{
'score': score.item(),
'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]:feature.token_to_orig_map[e] + 1])
"score": score.item(),
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
"answer": " ".join(
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
),
}
for s, e, score in zip(starts, ends, scores)
]
......@@ -767,71 +831,71 @@ class QuestionAnsweringPipeline(Pipeline):
chars_idx += len(word) + 1
# Join text with spaces
return {'answer': ' '.join(words), 'start': max(0, char_start_idx), 'end': min(len(text), char_end_idx)}
return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)}
# Register all the supported tasks here
SUPPORTED_TASKS = {
'feature-extraction': {
'impl': FeatureExtractionPipeline,
'tf': TFAutoModel if is_tf_available() else None,
'pt': AutoModel if is_torch_available() else None,
'default': {
'model': {
'pt': 'distilbert-base-uncased',
'tf': 'distilbert-base-uncased',
},
'config': None,
'tokenizer': 'distilbert-base-uncased'
}
"feature-extraction": {
"impl": FeatureExtractionPipeline,
"tf": TFAutoModel if is_tf_available() else None,
"pt": AutoModel if is_torch_available() else None,
"default": {
"model": {"pt": "distilbert-base-uncased", "tf": "distilbert-base-uncased",},
"config": None,
"tokenizer": "distilbert-base-uncased",
},
},
'sentiment-analysis': {
'impl': TextClassificationPipeline,
'tf': TFAutoModelForSequenceClassification if is_tf_available() else None,
'pt': AutoModelForSequenceClassification if is_torch_available() else None,
'default': {
'model': {
'pt': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin',
'tf': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-tf_model.h5',
"sentiment-analysis": {
"impl": TextClassificationPipeline,
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
"pt": AutoModelForSequenceClassification if is_torch_available() else None,
"default": {
"model": {
"pt": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin",
"tf": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-tf_model.h5",
},
'config': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json',
'tokenizer': 'distilbert-base-uncased'
}
"config": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json",
"tokenizer": "distilbert-base-uncased",
},
},
'ner': {
'impl': NerPipeline,
'tf': TFAutoModelForTokenClassification if is_tf_available() else None,
'pt': AutoModelForTokenClassification if is_torch_available() else None,
'default': {
'model': {
'pt':'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin',
'tf': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-tf_model.h5',
"ner": {
"impl": NerPipeline,
"tf": TFAutoModelForTokenClassification if is_tf_available() else None,
"pt": AutoModelForTokenClassification if is_torch_available() else None,
"default": {
"model": {
"pt": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin",
"tf": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-tf_model.h5",
},
'config': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json',
'tokenizer': 'bert-large-cased'
}
"config": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json",
"tokenizer": "bert-large-cased",
},
},
'question-answering': {
'impl': QuestionAnsweringPipeline,
'tf': TFAutoModelForQuestionAnswering if is_tf_available() else None,
'pt': AutoModelForQuestionAnswering if is_torch_available() else None,
'default': {
'model': {
'pt': 'distilbert-base-uncased-distilled-squad',
'tf': 'distilbert-base-uncased-distilled-squad',
"question-answering": {
"impl": QuestionAnsweringPipeline,
"tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
"pt": AutoModelForQuestionAnswering if is_torch_available() else None,
"default": {
"model": {
"pt": "distilbert-base-uncased-distilled-squad",
"tf": "distilbert-base-uncased-distilled-squad",
},
'config': None,
'tokenizer': 'distilbert-base-uncased'
}
}
"config": None,
"tokenizer": "distilbert-base-uncased",
},
},
}
def pipeline(task: str, model: Optional = None,
config: Optional[Union[str, PretrainedConfig]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
modelcard: Optional[Union[str, ModelCard]] = None,
**kwargs) -> Pipeline:
def pipeline(
task: str,
model: Optional = None,
config: Optional[Union[str, PretrainedConfig]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
modelcard: Optional[Union[str, ModelCard]] = None,
**kwargs
) -> Pipeline:
"""
Utility factory method to build a pipeline.
Pipelines are made of:
......@@ -852,11 +916,11 @@ def pipeline(task: str, model: Optional = None,
framework = get_framework(model)
targeted_task = SUPPORTED_TASKS[task]
task, model_class = targeted_task['impl'], targeted_task[framework]
task, model_class = targeted_task["impl"], targeted_task[framework]
# Use default model/config/tokenizer for the task if no model is provided
if model is None:
models, config, tokenizer = tuple(targeted_task['default'].values())
models, config, tokenizer = tuple(targeted_task["default"].values())
model = models[framework]
# Try to infer tokenizer from model or config name (if provided as str)
......@@ -867,8 +931,10 @@ def pipeline(task: str, model: Optional = None,
tokenizer = config
else:
# Impossible to guess which tokenizer is the right one here
raise Exception("Impossible to guess which tokenizer to use. "
"Please provide a PreTrainedTokenizer class or a path/url/shortcut name to a pretrained tokenizer.")
raise Exception(
"Impossible to guess which tokenizer to use. "
"Please provide a PreTrainedTokenizer class or a path/url/shortcut name to a pretrained tokenizer."
)
# Try to infer modelcard from model or config name (if provided as str)
if modelcard is None:
......@@ -894,14 +960,18 @@ def pipeline(task: str, model: Optional = None,
if isinstance(model, str):
# Handle transparent TF/PT model conversion
model_kwargs = {}
if framework == 'pt' and model.endswith('.h5'):
model_kwargs['from_tf'] = True
logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. '
'Trying to load the model with PyTorch.')
elif framework == 'tf' and model.endswith('.bin'):
model_kwargs['from_pt'] = True
logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. '
'Trying to load the model with TensorFlow.')
if framework == "pt" and model.endswith(".h5"):
model_kwargs["from_tf"] = True
logger.warning(
"Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
"Trying to load the model with PyTorch."
)
elif framework == "tf" and model.endswith(".bin"):
model_kwargs["from_pt"] = True
logger.warning(
"Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
"Trying to load the model with TensorFlow."
)
model = model_class.from_pretrained(model, config=config, **model_kwargs)
return task(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, **kwargs)
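Putting it together, the factory resolves the task entry, framework, model, tokenizer and modelcard before instantiating the pipeline class; a usage sketch relying on the defaults registered in SUPPORTED_TASKS (the printed output is illustrative):

# Hedged usage of the pipeline factory; output values are illustrative.
nlp = pipeline("sentiment-analysis")
print(nlp("Black makes this codebase much easier to read."))
# e.g. [{'label': 'POSITIVE', 'score': 0.99}]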
......@@ -32,10 +32,10 @@ class ConfigTester(object):
def create_and_test_config_common_properties(self):
config = self.config_class(**self.inputs_dict)
self.parent.assertTrue(hasattr(config, 'vocab_size'))
self.parent.assertTrue(hasattr(config, 'hidden_size'))
self.parent.assertTrue(hasattr(config, 'num_attention_heads'))
self.parent.assertTrue(hasattr(config, 'num_hidden_layers'))
self.parent.assertTrue(hasattr(config, "vocab_size"))
self.parent.assertTrue(hasattr(config, "hidden_size"))
self.parent.assertTrue(hasattr(config, "num_attention_heads"))
self.parent.assertTrue(hasattr(config, "num_hidden_layers"))
def create_and_test_config_to_json_string(self):
config = self.config_class(**self.inputs_dict)
......@@ -68,5 +68,6 @@ class ConfigTester(object):
self.create_and_test_config_to_json_file()
self.create_and_test_config_from_and_save_pretrained()
if __name__ == "__main__":
unittest.main()
\ No newline at end of file
unittest.main()
......@@ -28,20 +28,15 @@ PASS = "__DUMMY_TRANSFORMERS_PASS__"
FILES = [
(
"Test-{}.txt".format(int(time.time())),
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/input.txt"
)
os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/input.txt"),
),
(
"yoyo {}.txt".format(int(time.time())), # space is intentional
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"
)
"yoyo {}.txt".format(int(time.time())), # space is intentional
os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
),
]
class HfApiCommonTest(unittest.TestCase):
_api = HfApi(endpoint="https://moon-staging.huggingface.co")
......@@ -76,11 +71,9 @@ class HfApiEndpointsTest(HfApiCommonTest):
def test_presign_and_upload(self):
for FILE_KEY, FILE_PATH in FILES:
access_url = self._api.presign_and_upload(
token=self._token, filename=FILE_KEY, filepath=FILE_PATH
)
access_url = self._api.presign_and_upload(token=self._token, filename=FILE_KEY, filepath=FILE_PATH)
self.assertIsInstance(access_url, six.string_types)
with open(FILE_PATH, 'r') as f:
with open(FILE_PATH, "r") as f:
body = f.read()
r = requests.get(access_url)
self.assertEqual(r.text, body)
......@@ -93,7 +86,6 @@ class HfApiEndpointsTest(HfApiCommonTest):
self.assertIsInstance(o, S3Obj)
class HfFolderTest(unittest.TestCase):
def test_token_workflow(self):
"""
......@@ -102,18 +94,12 @@ class HfFolderTest(unittest.TestCase):
"""
token = "token-{}".format(int(time.time()))
HfFolder.save_token(token)
self.assertEqual(
HfFolder.get_token(),
token
)
self.assertEqual(HfFolder.get_token(), token)
HfFolder.delete_token()
HfFolder.delete_token()
# ^^ not an error, we test that the
# second call does not fail.
self.assertEqual(
HfFolder.get_token(),
None
)
self.assertEqual(HfFolder.get_token(), None)
if __name__ == "__main__":
......
......@@ -21,44 +21,39 @@ import unittest
from transformers.modelcard import ModelCard
from .tokenization_tests_commons import TemporaryDirectory
class ModelCardTester(unittest.TestCase):
class ModelCardTester(unittest.TestCase):
def setUp(self):
self.inputs_dict = {'model_details': {
'Organization': 'testing',
'Model date': 'today',
'Model version': 'v2.1, Developed by Test Corp in 2019.',
'Architecture': 'Convolutional Neural Network.',
},
'metrics': 'BLEU and ROUGE-1',
'evaluation_data':{
'Datasets':{
'BLEU': 'My-great-dataset-v1',
'ROUGE-1': 'My-short-dataset-v2.1',
},
'Preprocessing': 'See details on https://arxiv.org/pdf/1810.03993.pdf'
},
'training_data':{
'Dataset': 'English Wikipedia dump dated 2018-12-01',
'Preprocessing': 'Using SentencePiece vocabulary of size 52k tokens. See details on https://arxiv.org/pdf/1810.03993.pdf'
},
'quantitative_analyses': {
'BLEU': 55.1,
'ROUGE-1': 76,
},
}
self.inputs_dict = {
"model_details": {
"Organization": "testing",
"Model date": "today",
"Model version": "v2.1, Developed by Test Corp in 2019.",
"Architecture": "Convolutional Neural Network.",
},
"metrics": "BLEU and ROUGE-1",
"evaluation_data": {
"Datasets": {"BLEU": "My-great-dataset-v1", "ROUGE-1": "My-short-dataset-v2.1",},
"Preprocessing": "See details on https://arxiv.org/pdf/1810.03993.pdf",
},
"training_data": {
"Dataset": "English Wikipedia dump dated 2018-12-01",
"Preprocessing": "Using SentencePiece vocabulary of size 52k tokens. See details on https://arxiv.org/pdf/1810.03993.pdf",
},
"quantitative_analyses": {"BLEU": 55.1, "ROUGE-1": 76,},
}
def test_model_card_common_properties(self):
modelcard = ModelCard.from_dict(self.inputs_dict)
self.assertTrue(hasattr(modelcard, 'model_details'))
self.assertTrue(hasattr(modelcard, 'intended_use'))
self.assertTrue(hasattr(modelcard, 'factors'))
self.assertTrue(hasattr(modelcard, 'metrics'))
self.assertTrue(hasattr(modelcard, 'evaluation_data'))
self.assertTrue(hasattr(modelcard, 'training_data'))
self.assertTrue(hasattr(modelcard, 'quantitative_analyses'))
self.assertTrue(hasattr(modelcard, 'ethical_considerations'))
self.assertTrue(hasattr(modelcard, 'caveats_and_recommendations'))
self.assertTrue(hasattr(modelcard, "model_details"))
self.assertTrue(hasattr(modelcard, "intended_use"))
self.assertTrue(hasattr(modelcard, "factors"))
self.assertTrue(hasattr(modelcard, "metrics"))
self.assertTrue(hasattr(modelcard, "evaluation_data"))
self.assertTrue(hasattr(modelcard, "training_data"))
self.assertTrue(hasattr(modelcard, "quantitative_analyses"))
self.assertTrue(hasattr(modelcard, "ethical_considerations"))
self.assertTrue(hasattr(modelcard, "caveats_and_recommendations"))
def test_model_card_to_json_string(self):
modelcard = ModelCard.from_dict(self.inputs_dict)
......@@ -70,7 +65,7 @@ class ModelCardTester(unittest.TestCase):
model_card_first = ModelCard.from_dict(self.inputs_dict)
with TemporaryDirectory() as tmpdirname:
filename = os.path.join(tmpdirname, u"modelcard.json")
filename = os.path.join(tmpdirname, "modelcard.json")
model_card_first.to_json_file(filename)
model_card_second = ModelCard.from_json_file(filename)
......@@ -85,5 +80,6 @@ class ModelCardTester(unittest.TestCase):
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
if __name__ == "__main__":
unittest.main()
......@@ -20,14 +20,18 @@ import unittest
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .modeling_common_test import CommonTestCases, ids_tensor
from .configuration_common_test import ConfigTester
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM,
AlbertForSequenceClassification, AlbertForQuestionAnswering,
)
from transformers import (
AlbertConfig,
AlbertModel,
AlbertForMaskedLM,
AlbertForSequenceClassification,
AlbertForQuestionAnswering,
)
from transformers.modeling_albert import ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
......@@ -37,33 +41,33 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()
class AlbertModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
embedding_size=16,
hidden_size=36,
num_hidden_layers=6,
num_hidden_groups=6,
num_attention_heads=6,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
embedding_size=16,
hidden_size=36,
num_hidden_layers=6,
num_hidden_groups=6,
num_attention_heads=6,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
......@@ -120,16 +124,17 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
num_hidden_groups=self.num_hidden_groups)
num_hidden_groups=self.num_hidden_groups,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_albert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_albert_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = AlbertModel(config=config)
model.to(torch_device)
model.eval()
......@@ -142,66 +147,79 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_albert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_albert_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = AlbertForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result)
def create_and_check_albert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_albert_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = AlbertForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
start_positions=sequence_labels, end_positions=sequence_labels)
loss, start_logits, end_logits = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_albert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_albert_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = AlbertForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_labels])
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_mask,
sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict
def setUp(self):
......@@ -233,5 +251,6 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
......@@ -25,14 +25,21 @@ from transformers import is_torch_available
from .utils import require_torch, slow, SMALL_MODEL_IDENTIFIER
if is_torch_available():
from transformers import (AutoConfig, BertConfig,
AutoModel, BertModel,
AutoModelWithLMHead, BertForMaskedLM,
AutoModelForSequenceClassification, BertForSequenceClassification,
AutoModelForQuestionAnswering, BertForQuestionAnswering)
from transformers import (
AutoConfig,
BertConfig,
AutoModel,
BertModel,
AutoModelWithLMHead,
BertForMaskedLM,
AutoModelForSequenceClassification,
BertForSequenceClassification,
AutoModelForQuestionAnswering,
BertForQuestionAnswering,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .modeling_common_test import CommonTestCases, ids_tensor
from .configuration_common_test import ConfigTester
......@@ -75,7 +82,9 @@ class AutoModelTest(unittest.TestCase):
self.assertIsInstance(config, BertConfig)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True)
model, loading_info = AutoModelForSequenceClassification.from_pretrained(
model_name, output_loading_info=True
)
self.assertIsNotNone(model)
self.assertIsInstance(model, BertForSequenceClassification)
......
......@@ -20,51 +20,68 @@ import unittest
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
from .modeling_common_test import CommonTestCases, ids_tensor, floats_tensor
from .configuration_common_test import ConfigTester
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification, BertForMultipleChoice)
from transformers import (
BertConfig,
BertModel,
BertForMaskedLM,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertForMultipleChoice,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
@require_torch
class BertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification) if is_torch_available() else ()
all_model_classes = (
(
BertModel,
BertForMaskedLM,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
)
if is_torch_available()
else ()
)
class BertModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
......@@ -119,25 +136,44 @@ class BertModelTest(CommonTestCases.CommonModelTester):
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range)
initializer_range=self.initializer_range,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def prepare_config_and_inputs_for_decoder(self):
config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask
return (
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
)
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = BertModel(config=config)
model.to(torch_device)
model.eval()
......@@ -150,16 +186,38 @@ class BertModelTest(CommonTestCases.CommonModelTester):
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_bert_model_as_decoder(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask):
def create_and_check_bert_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = BertModel(config)
model.to(torch_device)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask)
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
sequence_output, pooled_output = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
sequence_output, pooled_output = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states,
)
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = {
......@@ -167,122 +225,171 @@ class BertModelTest(CommonTestCases.CommonModelTester):
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = BertForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result)
def create_and_check_bert_model_for_masked_lm_as_decoder(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask):
def create_and_check_bert_model_for_masked_lm_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = BertForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask)
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels, encoder_hidden_states=encoder_hidden_states)
loss, prediction_scores = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
masked_lm_labels=token_labels,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
loss, prediction_scores = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
masked_lm_labels=token_labels,
encoder_hidden_states=encoder_hidden_states,
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result)
def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_for_next_sequence_prediction(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = BertForNextSentencePrediction(config=config)
model.to(torch_device)
model.eval()
loss, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels)
loss, seq_relationship_score = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
next_sentence_label=sequence_labels,
)
result = {
"loss": loss,
"seq_relationship_score": seq_relationship_score,
}
self.parent.assertListEqual(
list(result["seq_relationship_score"].size()),
[self.batch_size, 2])
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
self.check_loss_output(result)
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_for_pretraining(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = BertForPreTraining(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
masked_lm_labels=token_labels, next_sentence_label=sequence_labels)
loss, prediction_scores, seq_relationship_score = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
masked_lm_labels=token_labels,
next_sentence_label=sequence_labels,
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
"seq_relationship_score": seq_relationship_score,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(result["seq_relationship_score"].size()),
[self.batch_size, 2])
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
self.check_loss_output(result)
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = BertForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
start_positions=sequence_labels, end_positions=sequence_labels)
loss, start_logits, end_logits = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = BertForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_labels])
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = BertForTokenClassification(config=config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.seq_length, self.num_labels])
list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
)
self.check_loss_output(result)
def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = BertForMultipleChoice(config=config)
model.to(torch_device)
......@@ -290,24 +397,31 @@ class BertModelTest(CommonTestCases.CommonModelTester):
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels)
loss, logits = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_choices])
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_mask,
sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict
def setUp(self):
......
......@@ -36,34 +36,48 @@ if is_torch_available():
import torch
import numpy as np
from transformers import (AdaptiveEmbedding, PretrainedConfig, PreTrainedModel,
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
from transformers import (
AdaptiveEmbedding,
PretrainedConfig,
PreTrainedModel,
BertModel,
BertConfig,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel,
GPT2Config,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
)
if sys.version_info[0] == 2:
import cPickle as pickle
class TemporaryDirectory(object):
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
def __enter__(self):
self.name = tempfile.mkdtemp()
return self.name
def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self.name)
else:
import pickle
TemporaryDirectory = tempfile.TemporaryDirectory
unicode = str
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
if '_range' in key or '_std' in key or 'initializer_factor' in key:
if "_range" in key or "_std" in key or "initializer_factor" in key:
setattr(configs_no_init, key, 0.0)
return configs_no_init
class CommonTestCases:
class CommonTestCases:
@require_torch
class CommonModelTester(unittest.TestCase):
......@@ -108,8 +122,11 @@ class CommonTestCases:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
self.assertIn(param.data.mean().item(), [0.0, 1.0],
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
self.assertIn(
param.data.mean().item(),
[0.0, 1.0],
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
)
def test_determinism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -131,10 +148,22 @@ class CommonTestCases:
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
decoder_seq_length = self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length
encoder_seq_length = self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length
decoder_key_length = self.model_tester.key_length if hasattr(self.model_tester, 'key_length') else decoder_seq_length
encoder_key_length = self.model_tester.key_length if hasattr(self.model_tester, 'key_length') else encoder_seq_length
decoder_seq_length = (
self.model_tester.decoder_seq_length
if hasattr(self.model_tester, "decoder_seq_length")
else self.model_tester.seq_length
)
encoder_seq_length = (
self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "encoder_seq_length")
else self.model_tester.seq_length
)
decoder_key_length = (
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else decoder_seq_length
)
encoder_key_length = (
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
)
for model_class in self.all_model_classes:
config.output_attentions = True
......@@ -150,23 +179,20 @@ class CommonTestCases:
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads,
encoder_seq_length ,
encoder_key_length])
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
if self.is_encoder_decoder:
self.assertEqual(out_len % 2, 0)
decoder_attentions = outputs[(out_len // 2)-1]
decoder_attentions = outputs[(out_len // 2) - 1]
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads,
decoder_seq_length,
decoder_key_length
])
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# Check attention is always last and order is fine
config.output_attentions = True
......@@ -184,9 +210,8 @@ class CommonTestCases:
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads,
encoder_seq_length,
encoder_key_length])
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_torchscript(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -215,7 +240,7 @@ class CommonTestCases:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = inputs_dict['input_ids'] # Let's keep only input_ids
inputs = inputs_dict["input_ids"] # Let's keep only input_ids
try:
traced_gpt2 = torch.jit.trace(model, inputs)
......@@ -269,12 +294,14 @@ class CommonTestCases:
# Prepare head_mask
# Set requires_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
head_mask = torch.ones(self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device)
head_mask = torch.ones(
self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device
)
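# Mask head 0 of the first layer and every head but the last in the final layer, then track gradients through the mask.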
head_mask[0, 0] = 0
head_mask[-1, :-1] = 0
head_mask.requires_grad_(requires_grad=True)
inputs = inputs_dict.copy()
inputs['head_mask'] = head_mask
inputs["head_mask"] = head_mask
outputs = model(**inputs)
......@@ -289,21 +316,20 @@ class CommonTestCases:
# Remove NaNs
for t in attentions:
self.assertLess(torch.sum(torch.isnan(t)), t.numel() / 4) # Check we don't have more than 25% nans (arbitrary)
attentions = [t.masked_fill(torch.isnan(t), 0.0) for t in attentions] # remove them (the test is less complete)
self.assertLess(
torch.sum(torch.isnan(t)), t.numel() / 4
) # Check we don't have more than 25% nans (arbitrary)
attentions = [
t.masked_fill(torch.isnan(t), 0.0) for t in attentions
] # remove them (the test is less complete)
self.assertIsNotNone(multihead_outputs)
self.assertEqual(len(multihead_outputs), self.model_tester.num_hidden_layers)
self.assertAlmostEqual(
attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
self.assertNotEqual(
attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
self.assertNotEqual(
attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
self.assertAlmostEqual(
attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
self.assertNotEqual(
attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
self.assertAlmostEqual(attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
self.assertNotEqual(attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
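# Head pruning: physically removing heads should shrink the head dimension of each layer's attention tensor.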
def test_head_pruning(self):
if not self.test_pruning:
......@@ -320,20 +346,16 @@ class CommonTestCases:
model = model_class(config=config)
model.to(torch_device)
model.eval()
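# Prune all heads but the first in layer 0, and only the first head in the last layer.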
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0]}
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
model.prune_heads(heads_to_prune)
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(
attentions[0].shape[-3], 1)
self.assertEqual(
attentions[1].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(
attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
self.assertEqual(attentions[0].shape[-3], 1)
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
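# The pruned-head structure should survive a save_pretrained/from_pretrained round trip.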
def test_head_pruning_save_load_from_pretrained(self):
if not self.test_pruning:
......@@ -350,8 +372,7 @@ class CommonTestCases:
model = model_class(config=config)
model.to(torch_device)
model.eval()
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0]}
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
model.prune_heads(heads_to_prune)
with TemporaryDirectory() as temp_dir_name:
......@@ -366,7 +387,6 @@ class CommonTestCases:
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
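# Declaring heads in config.pruned_heads at construction time should prune exactly like post-hoc pruning.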
def test_head_pruning_save_load_from_config_init(self):
if not self.test_pruning:
return
......@@ -380,8 +400,7 @@ class CommonTestCases:
config.output_attentions = True
config.output_hidden_states = False
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0]}
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
config.pruned_heads = heads_to_prune
model = model_class(config=config)
......@@ -446,7 +465,7 @@ class CommonTestCases:
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads -1)
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads - 2)
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
......@@ -470,8 +489,13 @@ class CommonTestCases:
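# Hidden states: one tensor per layer plus the embedding output, each ending in [seq_length, hidden_size].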
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length,
self.model_tester.hidden_size])
[
self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "encoder_seq_length")
else self.model_tester.seq_length,
self.model_tester.hidden_size,
],
)
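# Resizing the token embeddings should keep the model runnable and update config.vocab_size accordingly.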
def test_resize_tokens_embeddings(self):
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -512,15 +536,10 @@ class CommonTestCases:
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(
model.get_input_embeddings(),
(torch.nn.Embedding, AdaptiveEmbedding)
)
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Embedding, AdaptiveEmbedding))
model.set_input_embeddings(torch.nn.Embedding(10, 10))
x = model.get_output_embeddings()
self.assertTrue(
x is None or isinstance(x, torch.nn.Linear)
)
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
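# Weight tying: when the model exposes output embeddings, they should share their weights with the input embeddings.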
def test_tie_model_weights(self):
if not self.test_torchscript:
......@@ -602,30 +621,30 @@ class CommonTestCases:
outputs = model(**inputs_dict)
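# GPTModelTester builds a deliberately tiny GPT-style config and inputs so the common tests above run fast.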
class GPTModelTester(CommonModelTester):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_position_ids=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
n_positions=33,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
n_choices=3,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
scope=None,
config_class=None,
base_model_class=None,
lm_head_model_class=None,
double_head_model_class=None,
):
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_position_ids=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
n_positions=33,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
n_choices=3,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
scope=None,
config_class=None,
base_model_class=None,
lm_head_model_class=None,
double_head_model_class=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
......@@ -676,13 +695,14 @@ class CommonTestCases:
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,
initializer_range=self.initializer_range)
initializer_range=self.initializer_range,
)
return (config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids)
return (config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids)
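# Base model: hidden states should come back as [batch_size, n_choices, seq_length, hidden_size].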
def create_and_check_base_model(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
def create_and_check_base_model(
self, config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids
):
model = self.base_model_class(config)
model.to(torch_device)
model.eval()
......@@ -694,12 +714,12 @@ class CommonTestCases:
hidden_state = outputs[0]
self.parent.assertListEqual(
list(hidden_state.size()),
[self.batch_size, self.n_choices, self.seq_length, self.hidden_size])
list(hidden_state.size()), [self.batch_size, self.n_choices, self.seq_length, self.hidden_size]
)
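# LM head: logits should span the full vocabulary and the loss should reduce to a scalar.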
def create_and_check_lm_head(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
def create_and_check_lm_head(
self, config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids
):
model = self.lm_head_model_class(config)
model.to(torch_device)
model.eval()
......@@ -709,14 +729,13 @@ class CommonTestCases:
total_voc = self.vocab_size
self.parent.assertListEqual(
list(lm_logits.size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
self.parent.assertListEqual(
list(loss.size()),
[])
list(lm_logits.size()), [self.batch_size, self.n_choices, self.seq_length, total_voc]
)
self.parent.assertListEqual(list(loss.size()), [])
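# "Presents" are the per-layer cached key/value tensors returned to speed up autoregressive decoding.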
def create_and_check_presents(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
def create_and_check_presents(
self, config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids
):
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
......@@ -727,30 +746,39 @@ class CommonTestCases:
self.parent.assertEqual(self.num_hidden_layers, len(presents))
self.parent.assertListEqual(
list(presents[0].size()),
[2, self.batch_size * self.n_choices, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
[
2,
self.batch_size * self.n_choices,
self.num_attention_heads,
self.seq_length,
self.hidden_size // self.num_attention_heads,
],
)
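# Double-heads model: outputs an LM loss/logits pair plus a multiple-choice loss/logits pair.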
def create_and_check_double_heads(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
def create_and_check_double_heads(
self, config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids
):
model = self.double_head_model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels,
token_type_ids=token_type_ids, position_ids=position_ids)
outputs = model(
input_ids,
mc_token_ids,
lm_labels=lm_labels,
mc_labels=mc_labels,
token_type_ids=token_type_ids,
position_ids=position_ids,
)
lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4]
loss = [lm_loss, mc_loss]
total_voc = self.vocab_size
self.parent.assertListEqual(
list(lm_logits.size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
self.parent.assertListEqual(
list(mc_logits.size()),
[self.batch_size, self.n_choices])
self.parent.assertListEqual(
[list(l.size()) for l in loss],
[[], []])
list(lm_logits.size()), [self.batch_size, self.n_choices, self.seq_length, total_voc]
)
self.parent.assertListEqual(list(mc_logits.size()), [self.batch_size, self.n_choices])
self.parent.assertListEqual([list(l.size()) for l in loss], [[], []])
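# Smoke test: load just the first checkpoint listed in the pretrained archive map.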
def create_and_check_model_from_pretrained(self):
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
......@@ -759,9 +787,8 @@ class CommonTestCases:
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids) = config_and_inputs
inputs_dict = {'input_ids': input_ids}
(config, input_ids, token_type_ids, position_ids, mc_labels, lm_labels, mc_token_ids) = config_and_inputs
inputs_dict = {"input_ids": input_ids}
return config, inputs_dict
def run_common_tests(self, test_presents=False):
......@@ -791,10 +818,10 @@ class ConfigTester(object):
def create_and_test_config_common_properties(self):
config = self.config_class(**self.inputs_dict)
self.parent.assertTrue(hasattr(config, 'vocab_size'))
self.parent.assertTrue(hasattr(config, 'hidden_size'))
self.parent.assertTrue(hasattr(config, 'num_attention_heads'))
self.parent.assertTrue(hasattr(config, 'num_hidden_layers'))
self.parent.assertTrue(hasattr(config, "vocab_size"))
self.parent.assertTrue(hasattr(config, "hidden_size"))
self.parent.assertTrue(hasattr(config, "num_attention_heads"))
self.parent.assertTrue(hasattr(config, "num_hidden_layers"))
def create_and_test_config_to_json_string(self):
config = self.config_class(**self.inputs_dict)