Commit 34f28b2a authored by thomwolf

WIP GPT2

parent ad88563b
@@ -684,13 +684,13 @@ class TFBertModel(TFBertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertModel.from_pretrained('bert-base-uncased')
input_ids = tf.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(TFBertModel, self).__init__(config)
def __init__(self, config, *inputs, **kwargs):
super(TFBertModel, self).__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name='bert')
@tf.function
@@ -739,8 +739,8 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
prediction_scores, seq_relationship_scores = outputs[:2]
"""
def __init__(self, config):
super(TFBertForPreTraining, self).__init__(config)
def __init__(self, config, *inputs, **kwargs):
super(TFBertForPreTraining, self).__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name='bert')
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
@@ -790,8 +790,8 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
loss, prediction_scores = outputs[:2]
"""
def __init__(self, config):
super(TFBertForMaskedLM, self).__init__(config)
def __init__(self, config, *inputs, **kwargs):
super(TFBertForMaskedLM, self).__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name='bert')
@@ -839,8 +839,8 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
seq_relationship_scores = outputs[0]
"""
def __init__(self, config):
super(TFBertForNextSentencePrediction, self).__init__(config)
def __init__(self, config, *inputs, **kwargs):
super(TFBertForNextSentencePrediction, self).__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name='bert')
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
@@ -891,8 +891,8 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
loss, logits = outputs[:2]
"""
def __init__(self, config):
super(TFBertForSequenceClassification, self).__init__(config)
def __init__(self, config, *inputs, **kwargs):
super(TFBertForSequenceClassification, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name='bert')
@@ -984,8 +984,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
loss, classification_scores = outputs[:2]
"""
def __init__(self, config):
super(TFBertForMultipleChoice, self).__init__(config)
def __init__(self, config, *inputs, **kwargs):
super(TFBertForMultipleChoice, self).__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
@@ -1066,8 +1066,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
loss, scores = outputs[:2]
"""
def __init__(self, config):
super(TFBertForTokenClassification, self).__init__(config)
def __init__(self, config, *inputs, **kwargs):
super(TFBertForTokenClassification, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name='bert')
@@ -1128,8 +1128,8 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
loss, start_scores, end_scores = outputs[:3]
"""
def __init__(self, config):
super(TFBertForQuestionAnswering, self).__init__(config)
def __init__(self, config, *inputs, **kwargs):
super(TFBertForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name='bert')
......
@@ -128,8 +128,8 @@ class TFAttention(tf.keras.layers.Layer):
self.split_size = n_state
self.scale = scale
self.c_attn = TFConv1D(n_state * 3, nx)
self.c_proj = TFConv1D(n_state, nx)
self.c_attn = TFConv1D(n_state * 3, nx, name='c_attn')
self.c_proj = TFConv1D(n_state, nx, name='c_proj')
self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
self.pruned_heads = set()
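For context, `c_attn` projects to `n_state * 3` so that a single matmul yields the query, key and value tensors, which are then split along the last axis and reshaped per head. A standalone sketch of that split with dummy tensors (shapes are illustrative only):

```python
import tensorflow as tf

batch, seq_len, n_state, n_head = 2, 5, 768, 12

# Stand-in for the fused c_attn output: [..., 3 * n_state].
qkv = tf.random.normal([batch, seq_len, 3 * n_state])

# Split into query, key, value along the last axis, as the attention code expects.
query, key, value = tf.split(qkv, 3, axis=2)

def split_heads(x, n_head):
    # (batch, seq, features) -> (batch, head, seq, features // n_head)
    batch, seq, features = x.shape
    x = tf.reshape(x, [batch, seq, n_head, features // n_head])
    return tf.transpose(x, [0, 2, 1, 3])

print(split_heads(query, n_head).shape)  # (2, 12, 5, 64)
```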
@@ -139,7 +139,7 @@ class TFAttention(tf.keras.layers.Layer):
@staticmethod
@tf.function
def attention_mask(nd, ns, *, dtype):
def attention_mask(nd, ns, dtype):
"""1's in the lower triangle, counting from the lower right corner.
Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
"""
@@ -164,7 +164,8 @@ class TFAttention(tf.keras.layers.Layer):
w = w * b - 1e4 * (1 - b)
w = tf.nn.softmax(w)
w = self.attn_dropout(w, training=training)
if training:
w = self.attn_dropout(w)
# Mask heads if we want to
if head_mask is not None:
@@ -204,54 +205,238 @@ class TFAttention(tf.keras.layers.Layer):
value = tf.concat([past_value, value], axis=-2)
present = tf.stack([key, value], axis=1)
attn_outputs = self._attn(query, key, value, head_mask)
attn_outputs = self._attn(query, key, value, head_mask, training=training)
a = attn_outputs[0]
a = self.merge_heads(a)
a = self.c_proj(a)
a = self.resid_dropout(a, training=training)
if training:
a = self.resid_dropout(a)
outputs = [a, present] + attn_outputs[1:]
return outputs # a, present, (attentions)
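The `past`/`present` bookkeeping above is what makes incremental decoding cheap: cached keys and values from earlier steps are concatenated in front of the current step's keys and values, and the pair is stacked into `present` for reuse. A toy sketch of that mechanism with dummy tensors (no model weights involved):

```python
import tensorflow as tf

batch, n_head, head_features = 1, 12, 64

# Pretend these came from a previous forward pass over 5 tokens.
past_key = tf.random.normal([batch, n_head, 5, head_features])
past_value = tf.random.normal([batch, n_head, 5, head_features])

# New single-token step.
key = tf.random.normal([batch, n_head, 1, head_features])
value = tf.random.normal([batch, n_head, 1, head_features])

# Same concat/stack pattern as in TFAttention.call above.
key = tf.concat([past_key, key], axis=-2)
value = tf.concat([past_value, value], axis=-2)
present = tf.stack([key, value], axis=1)

print(key.shape)      # (1, 12, 6, 64): attention now spans 6 positions
print(present.shape)  # (1, 2, 12, 6, 64): cached for the next step
```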
class MLP(nn.Module):
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
super(MLP, self).__init__()
class TFMLP(tf.keras.layers.Layer):
def __init__(self, n_state, config, **kwargs):
super(TFMLP, self).__init__(**kwargs)
nx = config.n_embd
self.c_fc = TFConv1D(n_state, nx)
self.c_proj = TFConv1D(nx, n_state)
self.c_fc = TFConv1D(n_state, nx, name='c_fc')
self.c_proj = TFConv1D(nx, n_state, name='c_proj')
self.act = gelu
self.dropout = nn.Dropout(config.resid_pdrop)
self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)
def forward(self, x):
@tf.function
def call(self, x, training=False):
h = self.act(self.c_fc(x))
h2 = self.c_proj(h)
return self.dropout(h2)
if training:
h2 = self.dropout(h2)
return h2
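The diff shows two ways to make dropout a no-op at inference time: passing `training=training` to the layer call, or guarding the call with `if training:`. A short sketch contrasting the two (a Keras `Dropout` layer is already the identity when called with `training=False`):

```python
import tensorflow as tf

dropout = tf.keras.layers.Dropout(0.5)
x = tf.ones([2, 4])

# Style used elsewhere in the TF port: let the layer handle the flag.
y_train = dropout(x, training=True)   # roughly half the entries zeroed, the rest scaled by 2
y_eval = dropout(x, training=False)   # identity

# Style used in this WIP diff: skip the call entirely outside training.
training = False
h = dropout(x, training=True) if training else x

print(tf.reduce_all(y_eval == x).numpy())  # True
```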
class TFBlock(tf.keras.layers.Layer):
def __init__(self, n_ctx, config, scale=False, **kwargs):
super(TFBlock, self).__init__(**kwargs)
nx = config.n_embd
self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
self.attn = Attention(nx, n_ctx, config, scale)
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config)
self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_1')
self.attn = TFAttention(nx, n_ctx, config, scale, name='attn')
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_2')
self.mlp = TFMLP(4 * nx, config, name='mlp')
def forward(self, x, layer_past=None, head_mask=None):
output_attn = self.attn(self.ln_1(x), layer_past=layer_past, head_mask=head_mask)
@tf.function
def call(self, x, layer_past=None, head_mask=None, training=False):
output_attn = self.attn(self.ln_1(x),
layer_past=layer_past,
head_mask=head_mask,
training=training)
a = output_attn[0] # output_attn: a, present, (attentions)
x = x + a
m = self.mlp(self.ln_2(x))
m = self.mlp(self.ln_2(x), training=training)
x = x + m
outputs = [x] + output_attn[1:]
return outputs # x, present, (attentions)
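The block follows GPT-2's pre-LayerNorm layout: normalize, transform, then add the result back to the residual stream, once for attention and once for the MLP. A standalone sketch of that dataflow with stand-in `Dense` sublayers (illustrative only, not the real sublayers):

```python
import tensorflow as tf

n_embd = 8
ln_1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
ln_2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
attn = tf.keras.layers.Dense(n_embd)   # stand-in for TFAttention
mlp = tf.keras.layers.Dense(n_embd)    # stand-in for TFMLP

def block(x):
    x = x + attn(ln_1(x))  # residual around the pre-normalized attention sublayer
    x = x + mlp(ln_2(x))   # residual around the pre-normalized MLP sublayer
    return x

x = tf.random.normal([1, 5, n_embd])
print(block(x).shape)  # (1, 5, 8)
```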
class TFGPT2Embeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config, **kwargs):
super(TFGPT2Embeddings, self).__init__(**kwargs)
self.vocab_size = config.vocab_size
self.hidden_size = config.n_embd
def build(self, input_shape):
"""Build shared word embedding layer
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
self.weight = self.add_weight(
"weight",
shape=[self.vocab_size, self.hidden_size],
initializer=tf.random_normal_initializer(
mean=0., stddev=self.hidden_size**-0.5))
super(TFGPT2Embeddings, self).build(input_shape)
@tf.function
def call(self, inputs, mode="embedding", training=False):
"""Get token embeddings of inputs.
Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
mode: string, a valid value is one of "embedding" and "linear".
Returns:
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
linear tensor, float32 with shape [batch_size, length, vocab_size].
Raises:
ValueError: if mode is not valid.
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
if mode == "embedding":
return self._embedding(inputs, training=training)
elif mode == "linear":
return self._linear(inputs)
else:
raise ValueError("mode {} is not valid.".format(mode))
def _embedding(self, input_ids, training=False):
"""Applies embedding based on inputs tensor."""
return tf.gather(self.weight, input_ids)
def _linear(self, inputs):
"""Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [batch_size, length, hidden_size]
Returns:
float32 tensor with shape [batch_size, length, vocab_size].
"""
batch_size = tf.shape(inputs)[0]
length = tf.shape(inputs)[1]
x = tf.reshape(inputs, [-1, self.hidden_size])
logits = tf.matmul(x, self.weight, transpose_b=True)
return tf.reshape(logits, [batch_size, length, self.vocab_size])
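The "embedding"/"linear" modes implement weight tying: the same matrix both looks up token vectors and produces output logits. A toy sketch of the two operations on a small random weight matrix (sizes chosen only for illustration):

```python
import tensorflow as tf

vocab_size, hidden_size = 10, 4
weight = tf.random.normal([vocab_size, hidden_size])

# "embedding" mode: ids -> vectors, i.e. a row lookup.
input_ids = tf.constant([[1, 5, 7]])                      # (batch, length)
embeddings = tf.gather(weight, input_ids)                 # (batch, length, hidden_size)

# "linear" mode: hidden states -> vocabulary logits with the *same* weights.
hidden_states = tf.random.normal([1, 3, hidden_size])
flat = tf.reshape(hidden_states, [-1, hidden_size])
logits = tf.reshape(tf.matmul(flat, weight, transpose_b=True), [1, 3, vocab_size])

print(embeddings.shape, logits.shape)  # (1, 3, 4) (1, 3, 10)
```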
class TFGPT2MainLayer(tf.keras.layers.Layer):
def __init__(self, config, *inputs, **kwargs):
super(TFGPT2MainLayer, self).__init__(*inputs, **kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.vocab_size = config.vocab_size
self.n_embd = config.n_embd
self.wte = TFGPT2Embeddings(config, name='wte')
self.wpe = tf.keras.layers.Embedding(config.n_positions, config.n_embd, name='wpe')
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
self.h = [TFBlock(config.n_ctx, config, scale=True, name='h_{}'.format(i)) for i in range(config.n_layer)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
raise NotImplementedError
@tf.function
def call(self, inputs, training=False):
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None
elif isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else None
token_type_ids = inputs[2] if len(inputs) > 2 else None
position_ids = inputs[3] if len(inputs) > 3 else None
head_mask = inputs[4] if len(inputs) > 4 else None
assert len(inputs) <= 5, "Too many inputs."
else:
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', None)
token_type_ids = inputs.get('token_type_ids', None)
position_ids = inputs.get('position_ids', None)
head_mask = inputs.get('head_mask', None)
assert len(inputs) <= 5, "Too many inputs."
if past is None:
past_length = 0
past = [None] * len(self.h)
else:
past_length = past[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
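The lines above are still the removed PyTorch version; the intent is that `position_ids` starts at `past_length`, so absolute positions stay consistent when generating with a cache. A hedged sketch of the same computation with TF ops (variable names mirror the diff, but this is not the ported code):

```python
import tensorflow as tf

input_ids = tf.constant([[50, 60, 70]])   # (batch, new_length): 3 fresh tokens
past_length = 5                           # 5 tokens already cached in `past`

# Positions continue where the cache left off: 5, 6, 7.
new_length = tf.shape(input_ids)[-1]
position_ids = tf.range(past_length, past_length + new_length)
position_ids = tf.broadcast_to(position_ids[None, :], tf.shape(input_ids))

print(position_ids.numpy())  # [[5 6 7]]
```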
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.config.n_layer
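The head-mask handling here is likewise still the PyTorch version (`.dim()`, `.unsqueeze()`, `.expand()`); the goal is to broadcast a per-head keep/drop mask so it applies to every layer, batch element and attention position. A hedged TF sketch of the same broadcasting (illustrative only):

```python
import tensorflow as tf

n_layer, n_head = 12, 12

# 1-D mask: one entry per head, shared by every layer. 1.0 keeps a head, 0.0 mutes it.
head_mask = tf.ones([n_head])
mask = head_mask[None, None, :, None, None]                 # (1, 1, n_head, 1, 1)
mask = tf.tile(mask, [n_layer, 1, 1, 1, 1])                 # (n_layer, 1, n_head, 1, 1)

# 2-D mask: a per-layer, per-head specification.
per_layer = tf.ones([n_layer, n_head])
mask2 = per_layer[:, None, :, None, None]                   # (n_layer, 1, n_head, 1, 1)

print(mask.shape, mask2.shape)
```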
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_ids.size(-1))
position_ids = position_ids.view(-1, position_ids.size(-1))
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
token_type_embeds = self.wte(token_type_ids)
else:
token_type_embeds = 0
hidden_states = inputs_embeds + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states)
class GPT2PreTrainedModel(PreTrainedModel):
output_shape = input_shape + (hidden_states.size(-1),)
presents = ()
all_attentions = []
all_hidden_states = ()
for i, (block, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = block(hidden_states, layer_past, head_mask[i])
hidden_states, present = outputs[:2]
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(outputs[2])
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(*output_shape)
# Add last hidden state
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
# leave the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,)
return outputs # last hidden state, presents, (all hidden_states), (attentions)
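A sketch of how this return convention is meant to be consumed during generation: the first element is the last hidden state, the second is `presents`, which is fed back as `past` on the next step so only the new token has to be processed. The `fake_gpt2` stub below is hypothetical and only mimics the output ordering:

```python
import tensorflow as tf

def fake_gpt2(input_ids, past=None):
    """Hypothetical stand-in with the same output convention: (last hidden state, presents)."""
    batch, length = input_ids.shape
    hidden = tf.random.normal([batch, length, 8])
    presents = "cached-keys-and-values"          # placeholder for the real present tensors
    return hidden, presents

past = None
input_ids = tf.constant([[464]])                 # one token per step once a cache exists
for _ in range(3):
    hidden_states, past = fake_gpt2(input_ids, past=past)[:2]
    # ...project hidden_states to logits with the tied embedding, pick the next token...
```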
class TFGPT2PreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
@@ -260,22 +445,6 @@ class GPT2PreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_gpt2
base_model_prefix = "transformer"
def __init__(self, *inputs, **kwargs):
super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
def _init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in
`Language Models are Unsupervised Multitask Learners`_
@@ -283,14 +452,26 @@ GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in
It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
corpus of ~40 GB of text data.
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
This model is a TF 2.0 Keras `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
refer to the TF 2.0 documentation for all matters related to general usage and behavior.
.. _`Language Models are Unsupervised Multitask Learners`:
https://openai.com/blog/better-language-models/
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
.. _`tf.keras.Model`:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model
Important note on the model inputs:
The inputs of the TF 2.0 models are slightly different from the PyTorch ones since
TF 2.0 Keras doesn't accept named arguments with default values for input Tensors.
More precisely, input Tensors are gathered in the first argument of the model call function: `model(inputs)`.
There are three possibilities to gather and feed the inputs to the model (see the sketch after this note):
- a single Tensor with input_ids only and nothing else: `model(input_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
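A minimal sketch of the three calling conventions listed above; `model` stands for any of the TF 2.0 classes in this diff and is left undefined here, so the calls are shown as comments:

```python
import tensorflow as tf

input_ids = tf.constant([[31, 51, 99]])
attention_mask = tf.constant([[1, 1, 1]])
token_type_ids = tf.constant([[0, 0, 0]])

# 1. A single tensor (input_ids only):
# outputs = model(input_ids)

# 2. A list, in the documented order:
# outputs = model([input_ids, attention_mask, token_type_ids])

# 3. A dictionary keyed by the documented input names:
# outputs = model({'input_ids': input_ids, 'token_type_ids': token_type_ids})
```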
Parameters:
config (:class:`~pytorch_transformers.GPT2Config`): Model configuration class with all the parameters of the model.
@@ -325,7 +506,7 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
@add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.",
GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class GPT2Model(GPT2PreTrainedModel):
class TFGPT2Model(TFGPT2PreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
@@ -345,37 +526,72 @@ class GPT2Model(GPT2PreTrainedModel):
Examples::
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
model = TFGPT2Model.from_pretrained('gpt2')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(GPT2Model, self).__init__(config)
def __init__(self, config, *inputs, **kwargs):
super(TFGPT2Model, self).__init__(config, *inputs, **kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.vocab_size = config.vocab_size
self.n_embd = config.n_embd
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.wpe = tf.keras.layers.Embedding(config.n_positions, config.n_embd, name='wpe')
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
self.h = [TFBlock(config.n_ctx, config, scale=True, name='h_{}'.format(i)) for i in range(config.n_layer)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')
self.init_weights()
def build(self, input_shape):
"""Build shared word embedding layer
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
with tf.name_scope("wte"):
# Create and initialize weights. The random normal initializer was chosen
# arbitrarily, and works well.
self.wte = self.add_weight(
"weight",
shape=[self.vocab_size, self.n_embd],
initializer=tf.random_normal_initializer(
mean=0., stddev=self.n_embd**-0.5))
super(TFGPT2Model, self).build(input_shape)
def _resize_token_embeddings(self, new_num_tokens):
self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
return self.wte
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)
raise NotImplementedError
@tf.function
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
input_ids = inputs
attention_mask, head_mask, position_ids, token_type_ids = None, None, None, None
elif isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else None
token_type_ids = inputs[2] if len(inputs) > 2 else None
position_ids = inputs[3] if len(inputs) > 3 else None
head_mask = inputs[4] if len(inputs) > 4 else None
assert len(inputs) <= 5, "Too many inputs."
else:
input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', None)
token_type_ids = inputs.get('token_type_ids', None)
position_ids = inputs.get('position_ids', None)
head_mask = inputs.get('head_mask', None)
assert len(inputs) <= 5, "Too many inputs."
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
if past is None:
past_length = 0
past = [None] * len(self.h)
......
@@ -52,7 +52,7 @@ class TFPreTrainedModel(tf.keras.Model):
base_model_prefix = ""
def __init__(self, config, *inputs, **kwargs):
super(TFPreTrainedModel, self).__init__()
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
if not isinstance(config, PretrainedConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
@@ -257,11 +257,11 @@ class TFPreTrainedModel(tf.keras.Model):
return model
class TFConv1D(tf.keras.layers.Layer):
def __init__(self, nf, nx):
def __init__(self, nf, nx, *inputs, **kwargs):
""" TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
Basically works like a Linear layer but the weights are transposed
"""
super(TFConv1D, self).__init__()
super(TFConv1D, self).__init__(*inputs, **kwargs)
self.nf = nf
self.nx = nx
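"Works like a Linear layer but the weights are transposed" means the kernel is stored as `(nx, nf)` and applied with a plain matmul over the flattened leading dimensions. A hedged standalone sketch of such a layer (not the library's implementation, just the idea):

```python
import tensorflow as tf

class Conv1DSketch(tf.keras.layers.Layer):
    """Linear map from nx features to nf features with an (nx, nf) kernel, GPT-style."""
    def __init__(self, nf, nx, **kwargs):
        super(Conv1DSketch, self).__init__(**kwargs)
        self.nf = nf
        self.nx = nx

    def build(self, input_shape):
        self.weight = self.add_weight("weight", shape=[self.nx, self.nf],
                                      initializer=tf.random_normal_initializer(stddev=0.02))
        self.bias = self.add_weight("bias", shape=[self.nf],
                                    initializer=tf.zeros_initializer())
        super(Conv1DSketch, self).build(input_shape)

    def call(self, x):
        bz, sl = tf.shape(x)[0], tf.shape(x)[1]
        x = tf.reshape(x, [-1, self.nx])                 # flatten batch and sequence dims
        x = tf.matmul(x, self.weight) + self.bias        # (bz * sl, nf)
        return tf.reshape(x, [bz, sl, self.nf])

layer = Conv1DSketch(nf=3 * 768, nx=768)
print(layer(tf.random.normal([2, 5, 768])).shape)  # (2, 5, 2304)
```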
......