Unverified Commit 21637d49 authored by Thomas Wolf, committed by GitHub

Merge branch 'master' into do_lower_case

parents 7246d3c2 de2696f6
@@ -220,7 +220,8 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
**past**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
@@ -252,7 +253,8 @@ class CTRLModel(CTRLPreTrainedModel):
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
@@ -437,7 +439,8 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
...
@@ -30,6 +30,7 @@ import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from .modeling_utils import PreTrainedModel, prune_linear_layer
from .configuration_distilbert import DistilBertConfig
@@ -702,3 +703,75 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
@add_start_docstrings("""DistilBert Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
DISTILBERT_START_DOCSTRING,
DISTILBERT_INPUTS_DOCSTRING)
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
Classification scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForTokenClassification.from_pretrained('distilbert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
def __init__(self, config):
super(DistilBertForTokenClassification, self).__init__(config)
self.num_labels = config.num_labels
self.distilbert = DistilBertModel(config)
self.dropout = nn.Dropout(config.dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, head_mask=None,
inputs_embeds=None, labels=None):
outputs = self.distilbert(input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)
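A note for readers of the new token classification head above: `scores` holds per-token logits, so predictions are recovered with an argmax over the last dimension. A minimal sketch continuing the docstring example (the masking comment mirrors the active-loss logic in forward() and is illustrative, not part of this diff)::

    predictions = torch.argmax(scores, dim=-1)  # (batch_size, sequence_length) of predicted label indices
    # with padded batches, positions where attention_mask == 0 should be ignored when evaluating,
    # just as the active-loss masking above excludes them from the loss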
@@ -298,7 +298,8 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
**past**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
@@ -330,7 +331,8 @@ class GPT2Model(GPT2PreTrainedModel):
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
@@ -503,7 +505,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
@@ -595,7 +598,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
...
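For context on the clarified `past` docstrings above, a minimal sketch of incremental decoding with GPT-2 (assuming the GPT2LMHeadModel/GPT2Tokenizer API of this library; greedy sampling and the prompt are illustrative)::

    import torch
    from transformers import GPT2Tokenizer, GPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    model.eval()

    generated = tokenizer.encode("The Manhattan bridge")
    context = torch.tensor([generated])
    past = None

    for _ in range(10):
        logits, past = model(context, past=past)[:2]  # past caches keys/values for every layer
        token = torch.argmax(logits[0, -1, :])
        generated.append(token.item())
        context = token.unsqueeze(0).unsqueeze(0)  # feed only the new token; earlier tokens are already in `past`

    print(tokenizer.decode(generated))

This is exactly the usage the added sentence describes: once a token's key/value states are cached in `past`, its id is not passed again as input.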
@@ -142,19 +142,25 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
def _embedding(self, inputs, training=False):
"""Applies embedding based on inputs tensor."""
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
if input_ids is not None:
input_shape = tf.shape(input_ids)
else:
input_shape = tf.shape(inputs_embeds)[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs_embeds is None:
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings, training=training)
return embeddings
@@ -460,6 +466,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
self.encoder = TFBertEncoder(config, name='encoder')
self.pooler = TFBertPooler(config, name='pooler')
def get_input_embeddings(self):
return self.embeddings
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -470,28 +479,39 @@ class TFBertMainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.shape
elif inputs_embeds is not None:
input_shape = inputs_embeds.shape[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
@@ -520,7 +540,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
sequence_output = encoder_outputs[0]
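To illustrate the new `inputs_embeds` path added above, a minimal sketch of feeding pre-computed embeddings instead of token ids (assuming TFBertModel wraps this main layer as elsewhere in the repository; shapes and the manual `tf.gather` lookup are illustrative)::

    import tensorflow as tf
    from transformers import BertTokenizer, TFBertModel

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = TFBertModel.from_pretrained('bert-base-uncased')

    input_ids = tf.constant([tokenizer.encode("Hello, my dog is cute")])  # (1, seq_length)
    # look up the word embeddings by hand, then bypass the id lookup entirely
    inputs_embeds = tf.gather(model.bert.embeddings.word_embeddings, input_ids)  # (1, seq_length, hidden_size)

    outputs_from_ids = model(input_ids)
    outputs_from_embeds = model({'inputs_embeds': inputs_embeds})  # input_ids stays None inside call()

Position and token-type embeddings are still added inside `_embedding`, so both calls should yield the same hidden states.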
@@ -702,6 +722,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
self.nsp = TFBertNSPHead(config, name='nsp___cls')
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
def get_output_embeddings(self):
return self.bert.embeddings
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
@@ -747,6 +770,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
def get_output_embeddings(self):
return self.bert.embeddings
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
@@ -892,33 +918,39 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
kernel_initializer=get_initializer(config.initializer_range),
name='classifier')
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None:
num_choices = tf.shape(input_ids)[1]
seq_length = tf.shape(input_ids)[2]
else:
num_choices = tf.shape(inputs_embeds)[1]
seq_length = tf.shape(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
outputs = self.bert(flat_inputs, training=training)
...
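For reference, the multiple-choice flattening above assumes inputs carry an extra choices axis; a small sketch of the expected shapes (values are illustrative)::

    batch_size, num_choices, seq_length = 2, 4, 16
    input_ids = tf.ones((batch_size, num_choices, seq_length), dtype=tf.int32)
    flat_input_ids = tf.reshape(input_ids, (-1, seq_length))  # (batch_size * num_choices, seq_length), fed to the shared encoder

When `inputs_embeds` is used instead of `input_ids`, the code above reads `num_choices` and `seq_length` from its second and third dimensions, i.e. it expects a (batch_size, num_choices, seq_length, hidden_size) tensor.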
@@ -192,6 +192,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
name='h_._{}'.format(i)) for i in range(config.n_layer)]
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")
def get_input_embeddings(self):
return self.w
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -201,7 +204,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
@@ -209,7 +212,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
past = inputs.get('past', past)
@@ -217,12 +221,20 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past is None:
past_length = 0
@@ -230,8 +242,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
else:
past_length = shape_list(past[0][0])[-2]
if position_ids is None:
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
position_ids = tf.tile(position_ids, [input_shape[0], 1])
# Attention mask.
if attention_mask is not None:
@@ -270,8 +282,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
token_type_embeds = 0
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
if inputs_embeds is None:
inputs_embeds = self.w(input_ids, mode='embedding')
seq_len = input_shape[-1]
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
@@ -480,6 +492,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
def get_output_embeddings(self):
return self.lm_head.input_embeddings
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_states = transformer_outputs[0]
...
@@ -96,7 +96,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range))
super(TFEmbeddings, self).build(input_shape)
def call(self, inputs, inputs_embeds=None, mode="embedding", training=False):
"""Get token embeddings of inputs.
Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
@@ -112,13 +112,13 @@ class TFEmbeddings(tf.keras.layers.Layer):
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
if mode == "embedding":
return self._embedding(inputs, inputs_embeds=inputs_embeds, training=training)
elif mode == "linear":
return self._linear(inputs)
else:
raise ValueError("mode {} is not valid.".format(mode))
def _embedding(self, inputs, inputs_embeds=None, training=False):
"""
Parameters
----------
@@ -136,14 +136,19 @@ class TFEmbeddings(tf.keras.layers.Layer):
else:
input_ids, position_ids = inputs
if input_ids is not None:
seq_length = tf.shape(input_ids)[1]
else:
seq_length = tf.shape(inputs_embeds)[1]
if position_ids is None:
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
if inputs_embeds is None:
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
embeddings = inputs_embeds + position_embeddings # (bs, max_seq_length, dim)
embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim)
embeddings = self.dropout(embeddings, training=training) # (bs, max_seq_length, dim)
return embeddings
@@ -398,28 +403,42 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings
self.transformer = TFTransformer(config, name="transformer") # Encoder
def get_input_embeddings(self):
return self.embeddings
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
def call(self, inputs, attention_mask=None, head_mask=None, inputs_embeds=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
assert len(inputs) <= 4, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', attention_mask)
head_mask = inputs.get('head_mask', head_mask)
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
assert len(inputs) <= 4, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.ones(input_shape) # (bs, seq_length)
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
# Prepare head mask if needed
@@ -432,7 +451,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
else:
head_mask = [None] * self.num_hidden_layers
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
tfmr_output = self.transformer([embedding_output, attention_mask, head_mask], training=training)
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
@@ -613,6 +632,9 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")
def get_output_embeddings(self):
return self.vocab_projector.input_embeddings
def call(self, inputs, **kwargs):
distilbert_output = self.distilbert(inputs, **kwargs)
...
@@ -219,6 +219,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
name='h_._{}'.format(i)) for i in range(config.n_layer)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')
def get_input_embeddings(self):
return self.wte
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -228,7 +231,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
@@ -236,7 +239,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
past = inputs.get('past', past)
@@ -244,17 +248,28 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past is None:
past_length = 0
past = [None] * len(self.h)
else:
past_length = shape_list(past[0][0])[-2]
if position_ids is None:
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
if attention_mask is not None:
# We create a 3D attention mask from a 2D tensor mask.
@@ -286,11 +301,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids, mode='embedding')
position_embeds = self.wpe(position_ids)
if token_type_ids is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
@@ -490,6 +504,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFGPT2MainLayer(config, name='transformer')
def get_output_embeddings(self):
return self.transformer.wte
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_states = transformer_outputs[0]
@@ -560,7 +577,10 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
self.transformer = TFGPT2MainLayer(config, name='transformer')
self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
def get_output_embeddings(self):
return self.transformer.wte
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
@@ -568,8 +588,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
past = inputs.get('past', past)
@@ -577,21 +598,25 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None:
input_shapes = shape_list(input_ids)
else:
input_shapes = shape_list(inputs_embeds)[:-1]
seq_length = input_shapes[-1]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
transformer_outputs = self.transformer(flat_inputs, training=training)
hidden_states = transformer_outputs[0]
...
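The `get_input_embeddings`/`get_output_embeddings` hooks added throughout this commit simply expose the shared embedding layers; a tiny sketch of what they return for TF GPT-2 (illustrative, based only on the accessors shown in the hunks above)::

    from transformers import TFGPT2LMHeadModel

    model = TFGPT2LMHeadModel.from_pretrained('gpt2')
    wte = model.transformer.get_input_embeddings()  # the token embedding layer (self.wte)
    assert model.get_output_embeddings() is wte     # GPT-2 ties input and output embeddings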
@@ -217,6 +217,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
scale=True,
name='h_._{}'.format(i)) for i in range(config.n_layer)]
def get_input_embeddings(self):
return self.tokens_embed
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -226,26 +229,38 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if position_ids is None:
position_ids = tf.range(input_shape[-1], dtype=tf.int32)[tf.newaxis, :]
if attention_mask is not None:
# We create a 3D attention mask from a 2D tensor mask.
@@ -277,11 +292,10 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
if inputs_embeds is None:
inputs_embeds = self.tokens_embed(input_ids, mode='embedding')
position_embeds = self.positions_embed(position_ids)
if token_type_ids is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
@@ -462,6 +476,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
super(TFOpenAIGPTLMHeadModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
def get_output_embeddings(self):
return self.transformer.tokens_embed
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_states = transformer_outputs[0]
@@ -524,36 +541,44 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
def get_output_embeddings(self):
return self.transformer.tokens_embed
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None:
input_shapes = shape_list(input_ids)
else:
input_shapes = shape_list(inputs_embeds)[:-1]
seq_length = input_shapes[-1]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
transformer_outputs = self.transformer(flat_inputs, training=training)
hidden_states = transformer_outputs[0]
...
@@ -48,13 +48,17 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
def _embedding(self, inputs, training=False):
"""Applies embedding based on inputs tensor."""
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
if input_ids is not None:
seq_length = tf.shape(input_ids)[1]
else:
seq_length = tf.shape(inputs_embeds)[1]
if position_ids is None:
position_ids = tf.range(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=tf.int32)[tf.newaxis, :]
return super(TFRobertaEmbeddings, self)._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
class TFRobertaMainLayer(TFBertMainLayer):
@@ -65,6 +69,9 @@ class TFRobertaMainLayer(TFBertMainLayer):
super(TFRobertaMainLayer, self).__init__(config, **kwargs)
self.embeddings = TFRobertaEmbeddings(config, name='embeddings')
def get_input_embeddings(self):
return self.embeddings
class TFRobertaPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
@@ -280,6 +287,9 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
self.roberta = TFRobertaMainLayer(config, name="roberta")
self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
def get_output_embeddings(self):
return self.lm_head.decoder
def call(self, inputs, **kwargs):
outputs = self.roberta(inputs, **kwargs)
...
...@@ -413,6 +413,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -413,6 +413,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
name='r_r_bias') name='r_r_bias')
super(TFTransfoXLMainLayer, self).build(input_shape) super(TFTransfoXLMainLayer, self).build(input_shape)
def get_input_embeddings(self):
return self.word_emb
def _resize_token_embeddings(self, new_num_tokens): def _resize_token_embeddings(self, new_num_tokens):
return self.word_emb return self.word_emb
...@@ -427,11 +430,11 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -427,11 +430,11 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
def _prune_heads(self, heads): def _prune_heads(self, heads):
raise NotImplementedError raise NotImplementedError
def init_mems(self, data): def init_mems(self, bsz):
if self.mem_len > 0: if self.mem_len > 0:
mems = [] mems = []
for i in range(self.n_layer): for i in range(self.n_layer):
empty = tf.zeros([self.mem_len, shape_list(data)[1], self.d_model]) empty = tf.zeros([self.mem_len, bsz, self.d_model])
mems.append(empty) mems.append(empty)
return mems return mems
...@@ -461,28 +464,37 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -461,28 +464,37 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
return new_mems return new_mems
def call(self, inputs, mems=None, head_mask=None, training=False): def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, training=False):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
input_ids = inputs[0] input_ids = inputs[0]
mems = inputs[1] if len(inputs) > 1 else mems mems = inputs[1] if len(inputs) > 1 else mems
head_mask = inputs[2] if len(inputs) > 2 else head_mask head_mask = inputs[2] if len(inputs) > 2 else head_mask
assert len(inputs) <= 3, "Too many inputs." inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
assert len(inputs) <= 4, "Too many inputs."
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids') input_ids = inputs.get('input_ids')
mems = inputs.get('mems', mems) mems = inputs.get('mems', mems)
head_mask = inputs.get('head_mask', head_mask) head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 3, "Too many inputs." inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
assert len(inputs) <= 4, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz] # so we transpose here from shape [bsz, len] to shape [len, bsz]
input_ids = tf.transpose(input_ids, perm=(1, 0)) if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_ids = tf.transpose(input_ids, perm=(1, 0))
qlen, bsz = shape_list(input_ids)
elif inputs_embeds is not None:
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
qlen, bsz = shape_list(inputs_embeds)[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if mems is None: if mems is None:
mems = self.init_mems(input_ids) mems = self.init_mems(bsz)
qlen, bsz = shape_list(input_ids)
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
...@@ -494,7 +506,10 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -494,7 +506,10 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
else: else:
head_mask = [None] * self.n_layer head_mask = [None] * self.n_layer
word_emb = self.word_emb(input_ids) if inputs_embeds is not None:
word_emb = inputs_embeds
else:
word_emb = self.word_emb(input_ids)
mlen = shape_list(mems[0])[0] if mems is not None else 0 mlen = shape_list(mems[0])[0] if mems is not None else 0
klen = mlen + qlen klen = mlen + qlen
...@@ -720,28 +735,33 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -720,28 +735,33 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
def reset_length(self, tgt_len, ext_len, mem_len): def reset_length(self, tgt_len, ext_len, mem_len):
self.transformer.reset_length(tgt_len, ext_len, mem_len) self.transformer.reset_length(tgt_len, ext_len, mem_len)
def init_mems(self, data): def init_mems(self, bsz):
return self.transformer.init_mems(data) return self.transformer.init_mems(bsz)
def call(self, inputs, mems=None, head_mask=None, labels=None, training=False): def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, labels=None, training=False):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
input_ids = inputs[0] input_ids = inputs[0]
mems = inputs[1] if len(inputs) > 1 else mems mems = inputs[1] if len(inputs) > 1 else mems
head_mask = inputs[2] if len(inputs) > 2 else head_mask head_mask = inputs[2] if len(inputs) > 2 else head_mask
labels = inputs[3] if len(inputs) > 3 else labels inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
assert len(inputs) <= 4, "Too many inputs." labels = inputs[4] if len(inputs) > 4 else labels
assert len(inputs) <= 5, "Too many inputs."
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids') input_ids = inputs.get('input_ids')
mems = inputs.get('mems', mems) mems = inputs.get('mems', mems)
head_mask = inputs.get('head_mask', head_mask) head_mask = inputs.get('head_mask', head_mask)
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
labels = inputs.get('labels', labels) labels = inputs.get('labels', labels)
assert len(inputs) <= 4, "Too many inputs." assert len(inputs) <= 5, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
bsz, tgt_len = shape_list(input_ids)[:2] if input_ids is not None:
bsz, tgt_len = shape_list(input_ids)[:2]
else:
bsz, tgt_len = shape_list(inputs_embeds)[:2]
transformer_outputs = self.transformer([input_ids, mems, head_mask], training=training) transformer_outputs = self.transformer([input_ids, mems, head_mask, inputs_embeds], training=training)
last_hidden = transformer_outputs[0] last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:] pred_hid = last_hidden[:, -tgt_len:]
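For orientation, a minimal sketch of how the new `inputs_embeds` path might be exercised from user code, following the dict-input branch shown in this hunk; the checkpoint name and shapes are illustrative assumptions, not part of this diff:
    import tensorflow as tf
    from transformers import TFTransfoXLLMHeadModel

    model = TFTransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
    d_model = model.config.d_model
    # hypothetical pre-computed embeddings of shape (batch, seq_len, d_model)
    inputs_embeds = tf.random.uniform((1, 5, d_model))
    outputs = model({'inputs_embeds': inputs_embeds})  # input_ids is omitted, so the embedding lookup is skipped
    prediction_scores = outputs[0]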
......
...@@ -65,6 +65,21 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -65,6 +65,21 @@ class TFPreTrainedModel(tf.keras.Model):
# Save config in model # Save config in model
self.config = config self.config = config
def get_input_embeddings(self):
""" Get model's input embeddings
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
return base_model.get_input_embeddings()
else:
raise NotImplementedError
def get_output_embeddings(self):
""" Get model's output embeddings
Return None if the model doesn't have output embeddings
"""
return None # Overwrite for models with output embeddings
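A minimal sketch of the new accessors in use, based on the Roberta overrides added elsewhere in this diff (the checkpoint name is illustrative):
    from transformers import TFRobertaForMaskedLM

    model = TFRobertaForMaskedLM.from_pretrained('roberta-base')
    input_embeddings = model.get_input_embeddings()    # delegated to the base TFRobertaMainLayer.embeddings
    output_embeddings = model.get_output_embeddings()  # the lm_head decoder here; None for models without an LM head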
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None): def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
""" Build a resized Embedding Variable from a provided token Embedding Module. """ Build a resized Embedding Variable from a provided token Embedding Module.
Increasing the size will add newly initialized vectors at the end Increasing the size will add newly initialized vectors at the end
...@@ -477,10 +492,10 @@ def shape_list(x): ...@@ -477,10 +492,10 @@ def shape_list(x):
return [dynamic[i] if s is None else s for i, s in enumerate(static)] return [dynamic[i] if s is None else s for i, s in enumerate(static)]
def get_initializer(initializer_range=0.02): def get_initializer(initializer_range=0.02):
"""Creates a `tf.initializers.truncated_normal` with the given range. """Creates a `tf.initializers.truncated_normal` with the given range.
Args: Args:
initializer_range: float, initializer range for stddev. initializer_range: float, initializer range for stddev.
Returns: Returns:
TruncatedNormal initializer with stddev = `initializer_range`. TruncatedNormal initializer with stddev = `initializer_range`.
""" """
return tf.keras.initializers.TruncatedNormal(stddev=initializer_range) return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
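For instance, the returned initializer can be wired into a Keras layer; a short sketch with an arbitrary layer size:
    import tensorflow as tf
    # projection layer whose kernel is drawn from a truncated normal with stddev 0.02
    dense = tf.keras.layers.Dense(768, kernel_initializer=get_initializer(initializer_range=0.02))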
...@@ -277,6 +277,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -277,6 +277,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
self.prune_heads({int(layer): list(map(int, heads))}) self.prune_heads({int(layer): list(map(int, heads))})
def get_input_embeddings(self):
return self.embeddings
def _resize_token_embeddings(self, new_num_tokens): def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError raise NotImplementedError
...@@ -288,7 +291,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -288,7 +291,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
raise NotImplementedError raise NotImplementedError
def call(self, inputs, attention_mask=None, langs=None, token_type_ids=None, def call(self, inputs, attention_mask=None, langs=None, token_type_ids=None,
position_ids=None, lengths=None, cache=None, head_mask=None, position_ids=None, lengths=None, cache=None, head_mask=None, inputs_embeds=None,
training=False): # removed: src_enc=None, src_len=None training=False): # removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
input_ids = inputs[0] input_ids = inputs[0]
...@@ -299,7 +302,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -299,7 +302,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
lengths = inputs[5] if len(inputs) > 5 else lengths lengths = inputs[5] if len(inputs) > 5 else lengths
cache = inputs[6] if len(inputs) > 6 else cache cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask head_mask = inputs[7] if len(inputs) > 7 else head_mask
assert len(inputs) <= 8, "Too many inputs." inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids') input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', attention_mask) attention_mask = inputs.get('attention_mask', attention_mask)
...@@ -309,16 +313,28 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -309,16 +313,28 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
lengths = inputs.get('lengths', lengths) lengths = inputs.get('lengths', lengths)
cache = inputs.get('cache', cache) cache = inputs.get('cache', cache)
head_mask = inputs.get('head_mask', head_mask) head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 8, "Too many inputs." inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
assert len(inputs) <= 9, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
bs, slen = shape_list(input_ids)
elif inputs_embeds is not None:
bs, slen = shape_list(inputs_embeds)[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if lengths is None: if lengths is None:
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1) if input_ids is not None:
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
else:
lengths = tf.convert_to_tensor([slen]*bs, tf.int32)
# mask = input_ids != self.pad_index # mask = input_ids != self.pad_index
# check inputs # check inputs
bs, slen = shape_list(input_ids)
# assert shape_list(lengths)[0] == bs # assert shape_list(lengths)[0] == bs
tf.debugging.assert_equal(shape_list(lengths)[0], bs) tf.debugging.assert_equal(shape_list(lengths)[0], bs)
# assert lengths.max().item() <= slen # assert lengths.max().item() <= slen
...@@ -358,7 +374,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -358,7 +374,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
head_mask = [None] * self.n_layers head_mask = [None] * self.n_layers
# do not recompute cached elements # do not recompute cached elements
if cache is not None: if cache is not None and input_ids is not None:
_slen = slen - cache['slen'] _slen = slen - cache['slen']
input_ids = input_ids[:, -_slen:] input_ids = input_ids[:, -_slen:]
position_ids = position_ids[:, -_slen:] position_ids = position_ids[:, -_slen:]
...@@ -368,8 +384,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -368,8 +384,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
attn_mask = attn_mask[:, -_slen:] attn_mask = attn_mask[:, -_slen:]
# embeddings # embeddings
tensor = self.embeddings(input_ids) if inputs_embeds is None:
tensor = tensor + self.position_embeddings(position_ids) inputs_embeds = self.embeddings(input_ids)
tensor = inputs_embeds + self.position_embeddings(position_ids)
if langs is not None and self.use_lang_emb: if langs is not None and self.use_lang_emb:
tensor = tensor + self.lang_embeddings(langs) tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None: if token_type_ids is not None:
...@@ -641,6 +659,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel): ...@@ -641,6 +659,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
self.transformer = TFXLMMainLayer(config, name='transformer') self.transformer = TFXLMMainLayer(config, name='transformer')
self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj') self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')
def get_output_embeddings(self):
return self.pred_layer.input_embeddings
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs) transformer_outputs = self.transformer(inputs, **kwargs)
......
...@@ -371,6 +371,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -371,6 +371,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
self.layer = [TFXLNetLayer(config, name='layer_._{}'.format(i)) for i in range(config.n_layer)] self.layer = [TFXLNetLayer(config, name='layer_._{}'.format(i)) for i in range(config.n_layer)]
self.dropout = tf.keras.layers.Dropout(config.dropout) self.dropout = tf.keras.layers.Dropout(config.dropout)
def get_input_embeddings(self):
return self.word_embedding
def build(self, input_shape): def build(self, input_shape):
initializer = get_initializer(self.initializer_range) initializer = get_initializer(self.initializer_range)
self.mask_emb = self.add_weight(shape=(1, 1, self.d_model), self.mask_emb = self.add_weight(shape=(1, 1, self.d_model),
...@@ -484,7 +487,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -484,7 +487,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
return pos_emb return pos_emb
def call(self, inputs, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, def call(self, inputs, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, training=False): token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, training=False):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
input_ids = inputs[0] input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
...@@ -494,7 +497,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -494,7 +497,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
input_mask = inputs[6] if len(inputs) > 6 else input_mask input_mask = inputs[6] if len(inputs) > 6 else input_mask
head_mask = inputs[7] if len(inputs) > 7 else head_mask head_mask = inputs[7] if len(inputs) > 7 else head_mask
assert len(inputs) <= 8, "Too many inputs." inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids') input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', attention_mask) attention_mask = inputs.get('attention_mask', attention_mask)
...@@ -504,7 +508,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -504,7 +508,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
token_type_ids = inputs.get('token_type_ids', token_type_ids) token_type_ids = inputs.get('token_type_ids', token_type_ids)
input_mask = inputs.get('input_mask', input_mask) input_mask = inputs.get('input_mask', input_mask)
head_mask = inputs.get('head_mask', head_mask) head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 8, "Too many inputs." inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
assert len(inputs) <= 9, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
...@@ -512,14 +517,23 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -512,14 +517,23 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# but we want a unified interface in the library with the batch size on the first dimension # but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end # so we move here the first dimension (batch) to the end
input_ids = tf.transpose(input_ids, perm=(1, 0)) if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_ids = tf.transpose(input_ids, perm=(1, 0))
qlen, bsz = shape_list(input_ids)[:2]
elif inputs_embeds is not None:
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
qlen, bsz = shape_list(inputs_embeds)[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
qlen, bsz = shape_list(input_ids)[:2]
mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0 mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
klen = mlen + qlen klen = mlen + qlen
...@@ -570,7 +584,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -570,7 +584,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
non_tgt_mask = None non_tgt_mask = None
##### Word embeddings and prepare h & g hidden states ##### Word embeddings and prepare h & g hidden states
word_emb_k = self.word_embedding(input_ids) if inputs_embeds is not None:
word_emb_k = inputs_embeds
else:
word_emb_k = self.word_embedding(input_ids)
output_h = self.dropout(word_emb_k, training=training) output_h = self.dropout(word_emb_k, training=training)
if target_mapping is not None: if target_mapping is not None:
word_emb_q = tf.tile(self.mask_emb, [tf.shape(target_mapping)[0], bsz, 1]) word_emb_q = tf.tile(self.mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
...@@ -854,6 +871,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel): ...@@ -854,6 +871,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
self.transformer = TFXLNetMainLayer(config, name='transformer') self.transformer = TFXLNetMainLayer(config, name='transformer')
self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss') self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss')
def get_output_embeddings(self):
return self.lm_loss.input_embeddings
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs) transformer_outputs = self.transformer(inputs, **kwargs)
hidden_state = transformer_outputs[0] hidden_state = transformer_outputs[0]
......
...@@ -315,6 +315,10 @@ class PreTrainedModel(nn.Module): ...@@ -315,6 +315,10 @@ class PreTrainedModel(nn.Module):
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
""" """
if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path:
logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
"https://github.com/google-research/google-research/issues/119 for more information.")
config = kwargs.pop('config', None) config = kwargs.pop('config', None)
state_dict = kwargs.pop('state_dict', None) state_dict = kwargs.pop('state_dict', None)
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop('cache_dir', None)
......
...@@ -23,90 +23,66 @@ from torch.optim.lr_scheduler import LambdaLR ...@@ -23,90 +23,66 @@ from torch.optim.lr_scheduler import LambdaLR
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
Removed (previous class-based schedules):

class ConstantLRSchedule(LambdaLR):
    """ Constant learning rate schedule.
    """
    def __init__(self, optimizer, last_epoch=-1):
        super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)

class WarmupConstantSchedule(LambdaLR):
    """ Linear warmup and then constant.
        Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
        Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
        Keeps multiplicative variable equal to 1. after warmup_steps.
    """
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        return 1.

class WarmupLinearSchedule(LambdaLR):
    """ Linear warmup and then linear decay.
        Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
        Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
        Linearly decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps.
    """
    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1, self.warmup_steps))
        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

class WarmupCosineSchedule(LambdaLR):
    """ Linear warmup and then cosine decay.
        Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
        Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
        Decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
        If `cycles` (default=0.5) is different from default, then the multiplicative variable follows cosine function after warmup.
    """
    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        # progress after warmup
        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
        return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))

class WarmupCosineWithHardRestartsSchedule(LambdaLR):
    """ Linear warmup and then cosine cycles with hard restarts.
        Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
        Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
        If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
        learning rate (with hard restarts).
    """
    def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1, self.warmup_steps))
        # progress after warmup
        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
        if progress >= 1.0:
            return 0.0
        return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0))))

Added (replacement function-based schedules):

def get_constant_schedule(optimizer, last_epoch=-1):
    """ Create a schedule with a constant learning rate.
    """
    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)

def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
    """ Create a schedule with a constant learning rate preceded by a warmup
    period during which the learning rate increases linearly between 0 and 1.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return 1.

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """ Create a schedule with a learning rate that decreases linearly after
    linearly increasing during a warmup period.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1):
    """ Create a schedule with a learning rate that decreases following the
    values of the cosine function between 0 and `pi * cycles` after a warmup
    period during which it increases linearly between 0 and 1.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1):
    """ Create a schedule with a learning rate that decreases following the
    values of the cosine function with several hard restarts, after a warmup
    period during which it increases linearly between 0 and 1.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.:
            return 0.
        return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.))))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
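A minimal usage sketch of the new function-style schedulers, assuming they are imported from this module (`transformers.optimization`, which also defines `AdamW` just below); step counts and learning rate are illustrative:
    import torch
    from transformers.optimization import AdamW, get_linear_schedule_with_warmup

    model = torch.nn.Linear(8, 8)  # stand-in for a real model
    optimizer = AdamW(model.parameters(), lr=5e-5)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)

    for step in range(1000):
        # ... forward / backward pass would go here ...
        optimizer.step()
        scheduler.step()  # update the learning rate once per optimizer step
        optimizer.zero_grad()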
class AdamW(Optimizer): class AdamW(Optimizer):
""" Implements Adam algorithm with weight decay fix. """ Implements Adam algorithm with weight decay fix.
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pytest
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
if is_torch_available():
from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM,
AlbertForSequenceClassification, AlbertForQuestionAnswering,
)
from transformers.modeling_albert import ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
class AlbertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()
class AlbertModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
embedding_size=16,
hidden_size=36,
num_hidden_layers=6,
num_hidden_groups=6,
num_attention_heads=6,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.num_hidden_groups = num_hidden_groups
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = AlbertConfig(
vocab_size_or_config_json_file=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
num_hidden_groups=self.num_hidden_groups)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_and_check_albert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = AlbertModel(config=config)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_albert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = AlbertForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_albert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = AlbertForQuestionAnswering(config=config)
model.eval()
loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
start_positions=sequence_labels, end_positions=sequence_labels)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_albert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_labels = self.num_labels
model = AlbertForSequenceClassification(config)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_labels])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_mask,
sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
return config, inputs_dict
def setUp(self):
self.model_tester = AlbertModelTest.AlbertModelTester(self)
self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_albert_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_for_masked_lm(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_for_sequence_classification(*config_and_inputs)
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = AlbertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -35,7 +35,7 @@ if is_torch_available(): ...@@ -35,7 +35,7 @@ if is_torch_available():
import torch import torch
import numpy as np import numpy as np
from transformers import (PretrainedConfig, PreTrainedModel, from transformers import (AdaptiveEmbedding, PretrainedConfig, PreTrainedModel,
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
else: else:
...@@ -468,9 +468,15 @@ class CommonTestCases: ...@@ -468,9 +468,15 @@ class CommonTestCases:
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
model.get_input_embeddings() self.assertIsInstance(
model.get_input_embeddings(),
(torch.nn.Embedding, AdaptiveEmbedding)
)
model.set_input_embeddings(torch.nn.Embedding(10, 10)) model.set_input_embeddings(torch.nn.Embedding(10, 10))
model.get_output_embeddings() x = model.get_output_embeddings()
self.assertTrue(
x is None or isinstance(x, torch.nn.Linear)
)
def test_tie_model_weights(self): def test_tie_model_weights(self):
if not self.test_torchscript: if not self.test_torchscript:
......
...@@ -23,6 +23,7 @@ from transformers import is_torch_available ...@@ -23,6 +23,7 @@ from transformers import is_torch_available
if is_torch_available(): if is_torch_available():
from transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM, from transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
DistilBertForTokenClassification,
DistilBertForQuestionAnswering, DistilBertForSequenceClassification) DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
else: else:
pytestmark = pytest.mark.skip("Require Torch") pytestmark = pytest.mark.skip("Require Torch")
...@@ -180,6 +181,21 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester): ...@@ -180,6 +181,21 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
[self.batch_size, self.num_labels]) [self.batch_size, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_distilbert_for_token_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_labels = self.num_labels
model = DistilBertForTokenClassification(config=config)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
...@@ -209,6 +225,10 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester): ...@@ -209,6 +225,10 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs) self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)
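Outside the test suite, the new token-classification head could be exercised roughly as follows; this is a hedged sketch with an illustrative checkpoint name, and the classification layer is freshly initialized, so the outputs are untrained:
    import torch
    from transformers import DistilBertTokenizer, DistilBertForTokenClassification

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertForTokenClassification.from_pretrained('distilbert-base-uncased', num_labels=3)
    input_ids = torch.tensor([tokenizer.encode("Hello world")])
    labels = torch.zeros_like(input_ids)  # dummy per-token labels in [0, num_labels)
    loss, logits = model(input_ids, labels=labels)[:2]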
# @pytest.mark.slow # @pytest.mark.slow
# def test_model_from_pretrained(self): # def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/" # cache_dir = "/tmp/transformers_test/"
......