Commit 07009830 authored by Rémi Louf

Add BertDecoderModel and Bert2Bert classes

I am not sure what happens when the class is initialized with the
pretrained weights.
parent 75feacf1
@@ -788,6 +788,110 @@ class BertModel(BertPreTrainedModel):
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


@add_start_docstrings("""A bare Bert decoder Model transformer outputting raw hidden-states without any specific head on top.
    The model follows the general transformer decoder architecture.""",
    BERT_START_DOCSTRING,
    BERT_INPUTS_DOCSTRING)
class BertDecoderModel(BertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during Bert pretraining. This output is usually *not* a good summary
            of the semantic content of the input; you're often better off averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        encoder = BertModel.from_pretrained('bert-base-uncased')
        model = BertDecoderModel.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        encoder_hidden_states = encoder(input_ids)[0]  # the decoder cross-attends over these
        outputs = model(input_ids, encoder_hidden_states)  # forward() requires the encoder's hidden states
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        super(BertDecoderModel, self).__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.decoder = BertDecoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings
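
    # Illustrative note (not part of this diff): this hook is normally reached
    # through the public API, e.g. model.resize_token_embeddings(len(tokenizer)),
    # which PreTrainedModel dispatches to _resize_token_embeddings.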

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.decoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
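        # e.g. a padding mask [1, 1, 0] becomes [0.0, 0.0, -10000.0] after the
        # line above: attended positions keep their raw attention scores, while
        # padded positions are pushed towards -inf and vanish under the softmax.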

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers
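        # e.g. a 1D head_mask of shape [num_heads] is expanded to
        # [num_hidden_layers, 1, num_heads, 1, 1], so the same heads are kept
        # or masked in every decoder layer, broadcast over batch, query and
        # key positions.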

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        decoder_outputs = self.decoder(embedding_output,
                                       encoder_outputs,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = decoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + decoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
BERT_START_DOCSTRING,
......@@ -1312,13 +1416,20 @@ class BertForQuestionAnswering(BertPreTrainedModel):
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
@add_start_docstrings("Bert encoder-decoder model",
@add_start_docstrings("Bert encoder-decoder model for sequence generation.",
    BERT_START_DOCSTRING,
    BERT_INPUTS_DOCSTRING)
class Bert2Bert(BertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
    Examples::

@@ -1328,45 +1439,33 @@ class Bert2Bert(BertPreTrainedModel):
        outputs = model(input)
        output_text = tokenizer.decode(outputs[0])
        print(output_text)

    References::

        [1] "Leveraging Pre-trained Checkpoints for Sequence Generation Tasks", S. Rothe, S. Narayan & A. Severyn (2019), arXiv:1907.12461
        [2] Tensor2Tensor library, https://github.com/tensorflow/tensor2tensor
    """
    def __init__(self, config):
        super(Bert2Bert, self).__init__(config)

        self.encoder = BertModel(config)
        self.decoder = BertDecoderModel(config)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        encoder_outputs = self.encoder(input_ids,
                                       attention_mask=attention_mask,
                                       token_type_ids=token_type_ids,
                                       position_ids=position_ids,
                                       head_mask=head_mask)
        encoder_output = encoder_outputs[0]

        # Placeholder decoder input: the decoder's embedding lookup expects
        # integer token indices (normal_() is only defined for floating-point
        # tensors), so we sample random ids uniformly over the vocabulary.
        decoder_input = torch.randint_like(input_ids, high=self.config.vocab_size)
        decoder_outputs = self.decoder(decoder_input,
                                       encoder_output,
                                       token_type_ids=token_type_ids,
                                       position_ids=position_ids,
                                       head_mask=head_mask)
        return decoder_outputs
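
A minimal usage sketch for the new classes, assuming the pytorch_transformers import path of this era and the same checkpoint as the docstring examples; whether from_pretrained meaningfully initializes the decoder weights is exactly the open question in the commit message, so this is illustrative only:

    import torch
    from pytorch_transformers import BertTokenizer
    from pytorch_transformers.modeling_bert import Bert2Bert

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = Bert2Bert.from_pretrained('bert-base-uncased')
    model.eval()

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # batch size 1
    with torch.no_grad():
        decoder_outputs = model(input_ids)
    sequence_output = decoder_outputs[0]  # shape (1, sequence_length, hidden_size)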