"vscode:/vscode.git/clone" did not exist on "c581ff51a8c22c10ae7e98abff318ea09613f6a0"
Unverified commit 7daacf00, authored by Julien Chaumond, committed by GitHub

Merge pull request #1695 from huggingface/models_inputs_embeds

model forwards can take an inputs_embeds param
parents a44f112f 00337e96
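
For context, a minimal sketch of how the new argument is meant to be used once this change lands. This is illustrative only and not part of the diff; the checkpoint name and the use of BERT's internal word_embeddings table are assumptions for the example.

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

input_ids = torch.tensor(tokenizer.encode("Hello world")).unsqueeze(0)  # (1, seq_length)

# Usual path: the model looks the ids up in its own embedding matrix.
outputs_from_ids = model(input_ids=input_ids)

# New path: build the embedded representation yourself and hand it to the model directly.
inputs_embeds = model.embeddings.word_embeddings(input_ids)  # (1, seq_length, hidden_size)
outputs_from_embeds = model(inputs_embeds=inputs_embeds)

# Both calls return a tuple whose first element has shape (batch_size, seq_length, hidden_size).
sequence_output = outputs_from_embeds[0]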
@@ -255,6 +255,10 @@ XXX_INPUTS_DOCSTRING = r"""
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
...
@@ -238,6 +238,10 @@ XXX_INPUTS_DOCSTRING = r"""
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
@@ -295,7 +299,7 @@ class XxxModel(XxxPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
@@ -449,14 +453,15 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
         outputs = self.transformer(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)
         pooled_output = outputs[1]
@@ -520,14 +525,15 @@ class XxxForTokenClassification(XxxPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
         outputs = self.transformer(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
@@ -603,14 +609,15 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 start_positions=None, end_positions=None):
         outputs = self.transformer(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
...
@@ -158,19 +158,26 @@ class BertEmbeddings(nn.Module):
         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
-        seq_length = input_ids.size(1)
+    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+        seq_length = input_shape[1]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
         if position_ids is None:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).expand(input_shape)
         if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
-        words_embeddings = self.word_embeddings(input_ids)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
-        embeddings = words_embeddings + position_embeddings + token_type_embeddings
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings)
         return embeddings
@@ -550,6 +557,10 @@ BERT_INPUTS_DOCSTRING = r"""
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
     **encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
         Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
         is configured as a decoder.
@@ -615,8 +626,8 @@ class BertModel(BertPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
-                head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
+                head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
         """ Forward pass on the Model.
         The model can behave as an encoder (with only self-attention) as well
@@ -632,12 +643,23 @@ class BertModel(BertPreTrainedModel):
             https://arxiv.org/abs/1706.03762
         """
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
         if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+            attention_mask = torch.ones(input_shape)
         if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones_like(input_ids)
+            encoder_attention_mask = torch.ones(input_shape)
         if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
@@ -649,8 +671,8 @@ class BertModel(BertPreTrainedModel):
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if attention_mask.dim() == 2:
             if self.config.is_decoder:
-                batch_size, seq_length = input_ids.size()
-                seq_ids = torch.arange(seq_length, device=input_ids.device)
+                batch_size, seq_length = input_shape
+                seq_ids = torch.arange(seq_length, device=device)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:
@@ -689,7 +711,7 @@ class BertModel(BertPreTrainedModel):
         else:
             head_mask = [None] * self.config.num_hidden_layers
-        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
+        embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
         encoder_outputs = self.encoder(embedding_output,
                                        attention_mask=extended_attention_mask,
                                        head_mask=head_mask,
@@ -754,14 +776,15 @@ class BertForPreTraining(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 masked_lm_labels=None, next_sentence_label=None):
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@@ -829,7 +852,7 @@ class BertForMaskedLM(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
         outputs = self.bert(input_ids,
@@ -837,6 +860,7 @@ class BertForMaskedLM(BertPreTrainedModel):
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
                             head_mask=head_mask,
+                            inputs_embeds=inputs_embeds,
                             encoder_hidden_states=encoder_hidden_states,
                             encoder_attention_mask=encoder_attention_mask)
@@ -908,14 +932,15 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 next_sentence_label=None):
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
         pooled_output = outputs[1]
@@ -975,14 +1000,15 @@ class BertForSequenceClassification(BertPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
         pooled_output = outputs[1]
@@ -1049,8 +1075,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
         num_choices = input_ids.shape[1]
         input_ids = input_ids.view(-1, input_ids.size(-1))
@@ -1062,7 +1088,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
         pooled_output = outputs[1]
@@ -1123,14 +1150,15 @@ class BertForTokenClassification(BertPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
@@ -1207,14 +1235,15 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 start_positions=None, end_positions=None):
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
...
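
The "more control" case the new docstring refers to, sketched for BERT. Again this is illustrative and not part of the change; the checkpoint name, sentences, and mixing weight are assumptions. Because the encoder now accepts pre-computed embeddings, a caller can manipulate them before the model sees them, for example interpolating between two inputs in embedding space:

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
embedding_table = model.embeddings.word_embeddings  # the lookup the model would otherwise apply itself

ids_a = torch.tensor(tokenizer.encode("The movie was great")).unsqueeze(0)
ids_b = torch.tensor(tokenizer.encode("The movie was awful")).unsqueeze(0)  # same token count as ids_a here

# A "soft" input halfway between the two token sequences in embedding space.
alpha = 0.5
mixed_embeds = alpha * embedding_table(ids_a) + (1.0 - alpha) * embedding_table(ids_b)

outputs = model(inputs_embeds=mixed_embeds)
sequence_output = outputs[0]  # (batch_size, seq_length, hidden_size)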
@@ -236,6 +236,10 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
@@ -302,17 +306,26 @@ class CTRLModel(CTRLPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)
-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_shape[-1])
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
         if past is None:
             past_length = 0
             past = [None] * len(self.h)
         else:
             past_length = past[0][0].size(-2)
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         # Attention mask.
         if attention_mask is not None:
@@ -354,9 +367,10 @@ class CTRLModel(CTRLPreTrainedModel):
             token_type_embeds = 0
         position_ids = position_ids.view(-1, input_shape[-1])
-        inputs_embeds = self.w(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.w(input_ids)
         # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
-        seq_len = input_ids.shape[-1]
+        seq_len = input_shape[-1]
         mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)
         inputs_embeds *= np.sqrt(self.d_model_size)
@@ -455,14 +469,15 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                past=past,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
...
@@ -387,6 +387,10 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
@@ -436,9 +440,18 @@ class DistilBertModel(DistilBertPreTrainedModel):
             self.transformer.layer[layer].attention.prune_heads(heads)
     def forward(self,
-                input_ids, attention_mask=None, head_mask=None):
+                input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
         if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)  # (bs, seq_length)
+            attention_mask = torch.ones(input_shape)  # (bs, seq_length)
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -455,8 +468,9 @@ class DistilBertModel(DistilBertPreTrainedModel):
         else:
             head_mask = [None] * self.config.num_hidden_layers
-        embedding_output = self.embeddings(input_ids)  # (bs, seq_length, dim)
-        tfmr_output = self.transformer(x=embedding_output,
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
+        tfmr_output = self.transformer(x=inputs_embeds,
                                        attn_mask=attention_mask,
                                        head_mask=head_mask)
         hidden_state = tfmr_output[0]
@@ -514,10 +528,11 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
     def get_output_embeddings(self):
         return self.vocab_projector
-    def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
+    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, masked_lm_labels=None):
         dlbrt_output = self.distilbert(input_ids=input_ids,
                                        attention_mask=attention_mask,
-                                       head_mask=head_mask)
+                                       head_mask=head_mask,
+                                       inputs_embeds=inputs_embeds)
         hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
         prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
         prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
@@ -578,10 +593,11 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
                                             attention_mask=attention_mask,
-                                            head_mask=head_mask)
+                                            head_mask=head_mask,
+                                            inputs_embeds=inputs_embeds)
         hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
         pooled_output = hidden_state[:, 0]  # (bs, dim)
         pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
@@ -652,10 +668,11 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, head_mask=None, start_positions=None, end_positions=None):
+    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
                                             attention_mask=attention_mask,
-                                            head_mask=head_mask)
+                                            head_mask=head_mask,
+                                            inputs_embeds=inputs_embeds)
         hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)
         hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)
...
@@ -313,6 +313,10 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
@@ -370,9 +374,17 @@ class GPT2Model(GPT2PreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)
-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_shape[-1])
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
         if position_ids is not None:
@@ -384,8 +396,9 @@ class GPT2Model(GPT2PreTrainedModel):
         else:
             past_length = past[0][0].size(-2)
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         # Attention mask.
         if attention_mask is not None:
@@ -419,7 +432,8 @@ class GPT2Model(GPT2PreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer
-        inputs_embeds = self.wte(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
@@ -520,14 +534,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                past=past,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
@@ -623,14 +638,15 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 mc_token_ids=None, lm_labels=None, mc_labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                past=past,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
...
@@ -322,6 +322,10 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
@@ -373,14 +377,22 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
         if position_ids is None:
-            # This was used when we had a single embedding matrice from position and token embeddings
-            # start = self.config.vocab_size + self.config.n_special
-            # end = start + input_ids.size(-1)
-            # position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
-            position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            # Code is different from when we had a single embedding matrice from position and token embeddings
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         # Attention mask.
         if attention_mask is not None:
@@ -413,11 +425,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_ids.size(-1))
-        position_ids = position_ids.view(-1, position_ids.size(-1))
-        inputs_embeds = self.tokens_embed(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.tokens_embed(input_ids)
         position_embeds = self.positions_embed(position_ids)
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
@@ -495,13 +504,14 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
@@ -587,13 +597,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 mc_token_ids=None, lm_labels=None, mc_labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
...
@@ -48,16 +48,24 @@ class RobertaEmbeddings(BertEmbeddings):
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size,
                                                 padding_idx=self.padding_idx)
-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
-        seq_length = input_ids.size(1)
+    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+        seq_length = input_shape[1]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
         if position_ids is None:
             # Position numbers begin at padding_idx+1. Padding symbols are ignored.
             # cf. fairseq's `utils.make_positions`
-            position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).expand(input_shape)
         return super(RobertaEmbeddings, self).forward(input_ids,
                                                       token_type_ids=token_type_ids,
-                                                      position_ids=position_ids)
+                                                      position_ids=position_ids,
+                                                      inputs_embeds=inputs_embeds)
 ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in
@@ -126,6 +134,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
@@ -222,13 +234,14 @@ class RobertaForMaskedLM(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 masked_lm_labels=None):
         outputs = self.roberta(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                position_ids=position_ids,
-                               head_mask=head_mask)
+                               head_mask=head_mask,
+                               inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
         prediction_scores = self.lm_head(sequence_output)
@@ -309,13 +322,14 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
         self.roberta = RobertaModel(config)
         self.classifier = RobertaClassificationHead(config)
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 labels=None):
         outputs = self.roberta(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                position_ids=position_ids,
-                               head_mask=head_mask)
+                               head_mask=head_mask,
+                               inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
         logits = self.classifier(sequence_output)
@@ -372,6 +386,10 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
     **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
         Labels for computing the multiple choice classification loss.
         Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
@@ -415,8 +433,8 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
-                position_ids=None, head_mask=None):
+    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
+                position_ids=None, head_mask=None, inputs_embeds=None):
         num_choices = input_ids.shape[1]
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
@@ -487,14 +505,15 @@ class RobertaForTokenClassification(BertPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
         outputs = self.roberta(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                position_ids=position_ids,
-                               head_mask=head_mask)
+                               head_mask=head_mask,
+                               inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
...
@@ -616,6 +616,10 @@ BERT_INPUTS_DOCSTRING = r"""
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
...
@@ -374,6 +374,10 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
...
@@ -508,6 +508,10 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare DistilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
...
@@ -408,6 +408,10 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.",
...
@@ -389,6 +389,10 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare OpenAI GPT transformer model outputing raw hidden-states without any specific head on top.",
...
@@ -157,6 +157,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare RoBERTa Model transformer outputing raw hidden-states without any specific head on top.",
...
@@ -626,6 +626,10 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
     Mask to nullify selected heads of the self-attention modules.
     Mask values selected in ``[0, 1]``:
     ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+    **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+        Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+        This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+        than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
...
...@@ -35,7 +35,7 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -35,7 +35,7 @@ class TFPreTrainedModel(tf.keras.Model):
r""" Base class for all TF models. r""" Base class for all TF models.
:class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads. as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
Class attributes (overridden by derived classes): Class attributes (overridden by derived classes):
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture. - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
......
...@@ -530,6 +530,10 @@ XLM_INPUTS_DOCSTRING = r""" ...@@ -530,6 +530,10 @@ XLM_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules. Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``: Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
""" """
@add_start_docstrings("The bare XLM Model transformer outputing raw hidden-states without any specific head on top.", @add_start_docstrings("The bare XLM Model transformer outputing raw hidden-states without any specific head on top.",
......
...@@ -762,6 +762,10 @@ XLNET_INPUTS_DOCSTRING = r""" ...@@ -762,6 +762,10 @@ XLNET_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules. Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``: Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
""" """
@add_start_docstrings("The bare XLNet Model transformer outputing raw hidden-states without any specific head on top.", @add_start_docstrings("The bare XLNet Model transformer outputing raw hidden-states without any specific head on top.",
......
...@@ -553,6 +553,10 @@ TRANSFO_XL_INPUTS_DOCSTRING = r""" ...@@ -553,6 +553,10 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules. Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``: Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
""" """
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", @add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
...@@ -657,12 +661,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -657,12 +661,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
logger.info("Head pruning is not implemented for Transformer-XL model") logger.info("Head pruning is not implemented for Transformer-XL model")
pass pass
def init_mems(self, data): def init_mems(self, bsz):
if self.mem_len > 0: if self.mem_len > 0:
mems = [] mems = []
param = next(self.parameters()) param = next(self.parameters())
for i in range(self.n_layer): for i in range(self.n_layer):
empty = torch.zeros(self.mem_len, data.size(1), self.config.d_model, empty = torch.zeros(self.mem_len, bsz, self.config.d_model,
dtype=param.dtype, device=param.device) dtype=param.dtype, device=param.device)
mems.append(empty) mems.append(empty)
...@@ -693,15 +697,22 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -693,15 +697,22 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return new_mems return new_mems
def forward(self, input_ids, mems=None, head_mask=None): def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None):
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz] # so we transpose here from shape [bsz, len] to shape [len, bsz]
input_ids = input_ids.transpose(0, 1).contiguous() if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_ids = input_ids.transpose(0, 1).contiguous()
qlen, bsz = input_ids.size()
elif inputs_embeds is not None:
inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if mems is None: if mems is None:
mems = self.init_mems(input_ids) mems = self.init_mems(bsz)
qlen, bsz = input_ids.size()
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
...@@ -718,7 +729,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -718,7 +729,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
else: else:
head_mask = [None] * self.n_layer head_mask = [None] * self.n_layer
word_emb = self.word_emb(input_ids) if inputs_embeds is not None:
word_emb = inputs_embeds
else:
word_emb = self.word_emb(input_ids)
mlen = mems[0].size(0) if mems is not None else 0 mlen = mems[0].size(0) if mems is not None else 0
klen = mlen + qlen klen = mlen + qlen
...@@ -860,14 +874,18 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): ...@@ -860,14 +874,18 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
def reset_length(self, tgt_len, ext_len, mem_len): def reset_length(self, tgt_len, ext_len, mem_len):
self.transformer.reset_length(tgt_len, ext_len, mem_len) self.transformer.reset_length(tgt_len, ext_len, mem_len)
def init_mems(self, data): def init_mems(self, bsz):
return self.transformer.init_mems(data) return self.transformer.init_mems(bsz)
def forward(self, input_ids, mems=None, head_mask=None, labels=None): def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None):
bsz = input_ids.size(0) if input_ids is not None:
tgt_len = input_ids.size(1) bsz, tgt_len = input_ids.size(0), input_ids.size(1)
elif inputs_embeds is not None:
bsz, tgt_len = inputs_embeds.size(0), inputs_embeds.size(1)
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask) transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask, inputs_embeds=inputs_embeds)
last_hidden = transformer_outputs[0] last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:] pred_hid = last_hidden[:, -tgt_len:]
......
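Following the new code path above, ``inputs_embeds`` is expected in ``(batch_size, sequence_length, d_model)`` layout, is transposed internally like ``input_ids``, and replaces the adaptive ``word_emb`` lookup, while ``init_mems`` now only needs the batch size. A hedged usage sketch (the checkpoint name is an assumption, not part of the PR):

    import torch
    from transformers import TransfoXLTokenizer, TransfoXLModel

    tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
    model = TransfoXLModel.from_pretrained("transfo-xl-wt103")

    input_ids = torch.tensor([tokenizer.encode("Hello , my dog is cute")])  # (1, seq_len)

    with torch.no_grad():
        # Reproduce the internal lookup manually with the adaptive word embedding,
        # then hand the vectors over; input_ids and inputs_embeds are mutually exclusive.
        inputs_embeds = model.word_emb(input_ids)        # (1, seq_len, d_model)

    outputs = model(inputs_embeds=inputs_embeds)         # mems are initialised from the batch size
    last_hidden_state, mems = outputs[0], outputs[1]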
...@@ -53,7 +53,7 @@ class PreTrainedModel(nn.Module): ...@@ -53,7 +53,7 @@ class PreTrainedModel(nn.Module):
r""" Base class for all models. r""" Base class for all models.
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads. as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
Class attributes (overridden by derived classes): Class attributes (overridden by derived classes):
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture. - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
......
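The docstring touched here advertises the two base-class helpers: (i) resizing the input embeddings and (ii) pruning self-attention heads. A short, hedged sketch of both on a PyTorch model (the model choice, added token and pruned head indices are arbitrary examples):

    from transformers import BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")

    # (i) resize the input embeddings, e.g. after adding new tokens to the tokenizer
    tokenizer.add_tokens(["<new_token>"])
    model.resize_token_embeddings(len(tokenizer))

    # (ii) prune heads of the self-attention modules: {layer index: [head indices]}
    model.prune_heads({0: [0, 1], 11: [2]})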