Unverified Commit 7daacf00 authored by Julien Chaumond, committed by GitHub

Merge pull request #1695 from huggingface/models_inputs_embeds

model forwards can take an inputs_embeds param
parents a44f112f 00337e96
@@ -311,6 +311,10 @@ XLM_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix provides.
"""
@add_start_docstrings("The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
@@ -421,14 +425,21 @@ class XLMModel(XLMPreTrainedModel):
        for layer, heads in heads_to_prune.items():
            self.attentions[layer].prune_heads(heads)

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None):  # removed: src_enc=None, src_len=None
    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None, inputs_embeds=None):  # removed: src_enc=None, src_len=None
        if input_ids is not None:
            bs, slen = input_ids.size()
        else:
            bs, slen = inputs_embeds.size()[:-1]

        if lengths is None:
            if input_ids is not None:
                lengths = (input_ids != self.pad_index).sum(dim=1).long()
            else:
                lengths = torch.LongTensor([slen] * bs)
        # mask = input_ids != self.pad_index

        # check inputs
        bs, slen = input_ids.size()
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen
        # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
@@ -442,10 +453,12 @@ class XLMModel(XLMPreTrainedModel):
        # if self.is_decoder and src_enc is not None:
        #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # position_ids
        if position_ids is None:
            position_ids = input_ids.new((slen,)).long()
            position_ids = torch.arange(slen, out=position_ids).unsqueeze(0)
            position_ids = torch.arange(slen, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand((bs, slen))
        else:
            assert position_ids.size() == (bs, slen)  # (slen, bs)
            # position_ids = position_ids.transpose(0, 1)
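        # note: `device` above is taken from whichever of input_ids / inputs_embeds was passed,
        # so the freshly created position_ids land on the same device as the actual inputs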
@@ -471,7 +484,7 @@ class XLMModel(XLMPreTrainedModel):
            head_mask = [None] * self.n_layers

        # do not recompute cached elements
        if cache is not None:
        if cache is not None and input_ids is not None:
            _slen = slen - cache['slen']
            input_ids = input_ids[:, -_slen:]
            position_ids = position_ids[:, -_slen:]
@@ -481,8 +494,10 @@ class XLMModel(XLMPreTrainedModel):
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
        tensor = self.embeddings(input_ids)
        tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
        if langs is not None and self.use_lang_emb:
            tensor = tensor + self.lang_embeddings(langs)
        if token_type_ids is not None:
@@ -624,8 +639,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def get_output_embeddings(self):
return self.pred_layer.proj
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
langs=langs,
@@ -633,7 +648,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
output = transformer_outputs[0]
outputs = self.pred_layer(output, labels)
@@ -685,8 +701,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
langs=langs,
@@ -694,7 +710,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
@@ -768,8 +785,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None):
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
langs=langs,
@@ -777,7 +794,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = transformer_outputs[0]
@@ -863,8 +881,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None,
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None,
is_impossible=None, cls_index=None, p_mask=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
@@ -873,7 +891,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
output = transformer_outputs[0]
......
@@ -558,6 +558,10 @@ XLNET_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix provides.
"""
@add_start_docstrings("The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
@@ -712,19 +716,29 @@ class XLNetModel(XLNetPreTrainedModel):
        pos_emb = pos_emb.to(next(self.parameters()))
        return pos_emb

    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
                token_type_ids=None, input_mask=None, head_mask=None):
    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None):
        # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
        # but we want a unified interface in the library with the batch size on the first dimension
        # so we move here the first dimension (batch) to the end
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_ids = input_ids.transpose(0, 1).contiguous()
            qlen, bsz = input_ids.shape[0], input_ids.shape[1]
        elif inputs_embeds is not None:
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
        input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
        attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
        perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
        target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None

        qlen, bsz = input_ids.shape[0], input_ids.shape[1]
        mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
        klen = mlen + qlen
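The guard added above makes ``input_ids`` and ``inputs_embeds`` mutually exclusive. A minimal sketch of the resulting calling contract (the checkpoint name and sentence are illustrative; the error messages are the ones raised above):

    import torch
    from transformers import XLNetModel, XLNetTokenizer

    tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
    model = XLNetModel.from_pretrained('xlnet-base-cased')
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
    inputs_embeds = model.get_input_embeddings()(input_ids)

    hidden = model(inputs_embeds=inputs_embeds)[0]  # ok: embeddings only
    try:
        model(input_ids=input_ids, inputs_embeds=inputs_embeds)
    except ValueError:
        pass  # "You cannot specify both input_ids and inputs_embeds at the same time"
    try:
        model()
    except ValueError:
        pass  # "You have to specify either input_ids or inputs_embeds"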
@@ -777,6 +791,9 @@ class XLNetModel(XLNetPreTrainedModel):
            non_tgt_mask = None

        ##### Word embeddings and prepare h & g hidden states
        if inputs_embeds is not None:
            word_emb_k = inputs_embeds
        else:
            word_emb_k = self.word_embedding(input_ids)
        output_h = self.dropout(word_emb_k)
        if target_mapping is not None:
@@ -924,8 +941,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_loss
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
mems=mems,
@@ -933,7 +950,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
logits = self.lm_loss(transformer_outputs[0])
@@ -998,8 +1016,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
mems=mems,
@@ -1007,7 +1025,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
output = transformer_outputs[0]
output = self.sequence_summary(output)
@@ -1049,6 +1068,10 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix provides.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
@@ -1093,9 +1116,9 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
    def forward(self, input_ids=None, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None,
                labels=None, head_mask=None):
                labels=None, head_mask=None, inputs_embeds=None):
        num_choices = input_ids.shape[1]

        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
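        # note: num_choices is read from input_ids above, so this multiple-choice head still
        # requires input_ids; inputs_embeds is only forwarded to the underlying transformer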
@@ -1106,7 +1129,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
transformer_outputs = self.transformer(flat_input_ids, token_type_ids=flat_token_type_ids,
input_mask=flat_input_mask, attention_mask=flat_attention_mask,
mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
head_mask=head_mask)
head_mask=head_mask, inputs_embeds=inputs_embeds)
output = transformer_outputs[0]
@@ -1178,8 +1201,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None):
outputs = self.transformer(input_ids,
@@ -1189,7 +1212,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]
@@ -1294,8 +1318,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None, is_impossible=None, cls_index=None, p_mask=None,):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
@@ -1304,7 +1328,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_states = transformer_outputs[0]
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
......
@@ -525,6 +525,19 @@ class CommonTestCases:
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
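
        # New common test: every model class must also accept precomputed embeddings passed
        # via the new inputs_embeds argument in place of input_ids.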
        def test_inputs_embeds(self):
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            input_ids = inputs_dict["input_ids"]
            del inputs_dict["input_ids"]

            for model_class in self.all_model_classes:
                model = model_class(config)
                model.eval()

                wte = model.get_input_embeddings()
                inputs_dict["inputs_embeds"] = wte(input_ids)
                outputs = model(**inputs_dict)
class GPTModelTester(CommonModelTester):
......