Unverified commit 7daacf00, authored by Julien Chaumond, committed by GitHub

Merge pull request #1695 from huggingface/models_inputs_embeds

model forwards can take an inputs_embeds param
parents a44f112f 00337e96
@@ -311,6 +311,10 @@ XLM_INPUTS_DOCSTRING = r"""
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
"""

@add_start_docstrings("The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
@@ -421,14 +425,21 @@ class XLMModel(XLMPreTrainedModel):
        for layer, heads in heads_to_prune.items():
            self.attentions[layer].prune_heads(heads)

-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None):  # removed: src_enc=None, src_len=None
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None):  # removed: src_enc=None, src_len=None
+        if input_ids is not None:
+            bs, slen = input_ids.size()
+        else:
+            bs, slen = inputs_embeds.size()[:-1]
+
        if lengths is None:
-            lengths = (input_ids != self.pad_index).sum(dim=1).long()
+            if input_ids is not None:
+                lengths = (input_ids != self.pad_index).sum(dim=1).long()
+            else:
+                lengths = torch.LongTensor([slen]*bs)
        # mask = input_ids != self.pad_index

        # check inputs
-        bs, slen = input_ids.size()
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen
        # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
@@ -442,10 +453,12 @@ class XLMModel(XLMPreTrainedModel):
        # if self.is_decoder and src_enc is not None:
        #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
        # position_ids
        if position_ids is None:
-            position_ids = input_ids.new((slen,)).long()
-            position_ids = torch.arange(slen, out=position_ids).unsqueeze(0)
+            position_ids = torch.arange(slen, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).expand((bs, slen))
        else:
            assert position_ids.size() == (bs, slen)  # (slen, bs)
            # position_ids = position_ids.transpose(0, 1)
@@ -471,7 +484,7 @@ class XLMModel(XLMPreTrainedModel):
            head_mask = [None] * self.n_layers

        # do not recompute cached elements
-        if cache is not None:
+        if cache is not None and input_ids is not None:
            _slen = slen - cache['slen']
            input_ids = input_ids[:, -_slen:]
            position_ids = position_ids[:, -_slen:]
@@ -481,8 +494,10 @@ class XLMModel(XLMPreTrainedModel):
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
-        tensor = self.embeddings(input_ids)
-        tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+
+        tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
        if langs is not None and self.use_lang_emb:
            tensor = tensor + self.lang_embeddings(langs)
        if token_type_ids is not None:
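Note that the hunk above only bypasses the word-embedding lookup: position (and, where applicable, language and token-type) embeddings are still added inside the forward pass, so passing the model's own embeddings through inputs_embeds reproduces the input_ids path exactly. A quick self-contained check, with the same default-config caveat as the earlier sketch:

import torch
from transformers import XLMConfig, XLMModel

model = XLMModel(XLMConfig())
model.eval()                                     # disables dropout so the two paths match exactly

input_ids = torch.tensor([[5, 6, 7, 8]])
inputs_embeds = model.get_input_embeddings()(input_ids)

with torch.no_grad():
    h_ids = model(input_ids=input_ids)[0]
    h_embeds = model(inputs_embeds=inputs_embeds)[0]

print(torch.allclose(h_ids, h_embeds))           # expected: True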
@@ -624,8 +639,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
    def get_output_embeddings(self):
        return self.pred_layer.proj

-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               langs=langs,
@@ -633,7 +648,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
                                               position_ids=position_ids,
                                               lengths=lengths,
                                               cache=cache,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)

        output = transformer_outputs[0]
        outputs = self.pred_layer(output, labels)
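Every head model touched by this diff follows the same pattern: accept inputs_embeds and forward it unchanged to the underlying transformer, so losses and logits can also be computed from pre-built embeddings. A sketch for the LM head, again with a default config purely for illustration:

import torch
from transformers import XLMConfig, XLMWithLMHeadModel

model = XLMWithLMHeadModel(XLMConfig())
model.eval()

input_ids = torch.tensor([[5, 6, 7, 8]])
inputs_embeds = model.get_input_embeddings()(input_ids)

# labels are only needed if you want the loss; omit them to get logits alone.
outputs = model(inputs_embeds=inputs_embeds, labels=input_ids)
loss, logits = outputs[0], outputs[1]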
@@ -685,8 +701,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
        self.init_weights()

-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               langs=langs,
@@ -694,7 +710,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
                                               position_ids=position_ids,
                                               lengths=lengths,
                                               cache=cache,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)

        output = transformer_outputs[0]
        logits = self.sequence_summary(output)
@@ -768,8 +785,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
        self.init_weights()

-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None):
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               langs=langs,
@@ -777,7 +794,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
                                               position_ids=position_ids,
                                               lengths=lengths,
                                               cache=cache,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)

        sequence_output = transformer_outputs[0]
@@ -863,8 +881,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
        self.init_weights()

-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None,
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None,
                is_impossible=None, cls_index=None, p_mask=None):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
@@ -873,7 +891,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
                                               position_ids=position_ids,
                                               lengths=lengths,
                                               cache=cache,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)

        output = transformer_outputs[0]
...
@@ -558,6 +558,10 @@ XLNET_INPUTS_DOCSTRING = r"""
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
"""

@add_start_docstrings("The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
@@ -712,19 +716,29 @@ class XLNetModel(XLNetPreTrainedModel):
        pos_emb = pos_emb.to(next(self.parameters()))
        return pos_emb

-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None):
        # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
        # but we want a unified interface in the library with the batch size on the first dimension
        # so we move here the first dimension (batch) to the end
-        input_ids = input_ids.transpose(0, 1).contiguous()
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_ids = input_ids.transpose(0, 1).contiguous()
+            qlen, bsz = input_ids.shape[0], input_ids.shape[1]
+        elif inputs_embeds is not None:
+            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
+            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
        token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
        input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
        attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
        perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
        target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None

-        qlen, bsz = input_ids.shape[0], input_ids.shape[1]
        mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
        klen = mlen + qlen
@@ -777,7 +791,10 @@ class XLNetModel(XLNetPreTrainedModel):
            non_tgt_mask = None

        ##### Word embeddings and prepare h & g hidden states
-        word_emb_k = self.word_embedding(input_ids)
+        if inputs_embeds is not None:
+            word_emb_k = inputs_embeds
+        else:
+            word_emb_k = self.word_embedding(input_ids)
        output_h = self.dropout(word_emb_k)
        if target_mapping is not None:
            word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
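As the comment in the forward notes, XLNet works internally with [len, bsz] tensors; inputs_embeds is therefore expected batch-first from the caller, transposed on entry, and substituted for the word_embedding lookup as word_emb_k, so outputs come back batch-first exactly as with input_ids. A short equivalence sketch under the same default-config assumption as above:

import torch
from transformers import XLNetConfig, XLNetModel

model = XLNetModel(XLNetConfig())
model.eval()

input_ids = torch.tensor([[5, 6, 7, 8]])                   # batch-first, as usual
inputs_embeds = model.get_input_embeddings()(input_ids)    # (batch_size, seq_len, d_model)

with torch.no_grad():
    h_ids = model(input_ids=input_ids)[0]
    h_embeds = model(inputs_embeds=inputs_embeds)[0]

print(h_embeds.shape)                                      # batch-first output, (1, 4, d_model)
print(torch.allclose(h_ids, h_embeds))                     # expected: True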
@@ -924,8 +941,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
    def get_output_embeddings(self):
        return self.lm_loss

-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               mems=mems,
@@ -933,7 +950,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
                                               target_mapping=target_mapping,
                                               token_type_ids=token_type_ids,
                                               input_mask=input_mask,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)

        logits = self.lm_loss(transformer_outputs[0])
@@ -998,8 +1016,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
        self.init_weights()

-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               mems=mems,
@@ -1007,7 +1025,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
                                               target_mapping=target_mapping,
                                               token_type_ids=token_type_ids,
                                               input_mask=input_mask,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)

        output = transformer_outputs[0]
        output = self.sequence_summary(output)
@@ -1049,6 +1068,10 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
@@ -1093,9 +1116,9 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
        self.init_weights()

-    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
+    def forward(self, input_ids=None, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None,
-                labels=None, head_mask=None):
+                labels=None, head_mask=None, inputs_embeds=None):
        num_choices = input_ids.shape[1]

        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
@@ -1106,7 +1129,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
        transformer_outputs = self.transformer(flat_input_ids, token_type_ids=flat_token_type_ids,
                                               input_mask=flat_input_mask, attention_mask=flat_attention_mask,
                                               mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask, inputs_embeds=inputs_embeds)

        output = transformer_outputs[0]
@@ -1178,8 +1201,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
        self.init_weights()

-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
                start_positions=None, end_positions=None):
        outputs = self.transformer(input_ids,
@@ -1189,7 +1212,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
                                   target_mapping=target_mapping,
                                   token_type_ids=token_type_ids,
                                   input_mask=input_mask,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)

        sequence_output = outputs[0]
@@ -1294,8 +1318,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
        self.init_weights()

-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
                start_positions=None, end_positions=None, is_impossible=None, cls_index=None, p_mask=None,):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
@@ -1304,7 +1328,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
                                               target_mapping=target_mapping,
                                               token_type_ids=token_type_ids,
                                               input_mask=input_mask,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)

        hidden_states = transformer_outputs[0]
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)
...
@@ -525,6 +525,19 @@ class CommonTestCases:
            # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
            # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

+        def test_inputs_embeds(self):
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            input_ids = inputs_dict["input_ids"]
+            del inputs_dict["input_ids"]
+
+            for model_class in self.all_model_classes:
+                model = model_class(config)
+                model.eval()
+
+                wte = model.get_input_embeddings()
+                inputs_dict["inputs_embeds"] = wte(input_ids)
+                outputs = model(**inputs_dict)
+
    class GPTModelTester(CommonModelTester):
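The new common test exercises exactly the pattern users are expected to follow: fetch the input embedding module with get_input_embeddings(), embed the ids, and call the model with inputs_embeds instead of input_ids. Beyond re-encoding ids, the point of the argument is that the embeddings no longer have to come from the lookup table at all. The sketch below is a hypothetical illustration of that (not part of this PR): feeding a blend of two token embeddings, a vector that has no input_ids representation. Token ids and the default XLMConfig are arbitrary choices.

import torch
from transformers import XLMConfig, XLMModel

model = XLMModel(XLMConfig())
model.eval()

wte = model.get_input_embeddings()
input_ids = torch.tensor([[5, 6, 7, 8]])

with torch.no_grad():
    inputs_embeds = wte(input_ids)
    # Replace the second position with a 50/50 mix of two other tokens' embeddings --
    # something the integer-id interface could never express.
    two = wte(torch.tensor([10, 11]))             # (2, embedding_dim)
    inputs_embeds[0, 1] = two.mean(dim=0)
    hidden_states = model(inputs_embeds=inputs_embeds)[0]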
...