Unverified commit a36f981d authored by Thomas Wolf, committed by GitHub

Merge branch 'master' into fix-ctrl-past

parents 151e4ab4 5afca00b
@@ -39,6 +39,7 @@ logger = logging.getLogger(__name__)
 GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
                                      "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
                                      "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin",
+                                     "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-pytorch_model.bin",
                                      "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",}
 def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
@@ -297,7 +298,8 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
         **past**:
             list of ``torch.FloatTensor`` (one for each layer):
             that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
-            (see `past` output below). Can be used to speed up sequential decoding.
+            (see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
+            should not be passed as input ids as they have already been computed.
         **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
@@ -313,6 +315,10 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
@@ -325,7 +331,8 @@ class GPT2Model(GPT2PreTrainedModel):
         **past**:
             list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
             that contains pre-computed hidden-states (key and values in the attention blocks).
-            Can be used (see `past` input) to speed up sequential decoding.
+            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
+            should not be passed as input ids as they have already been computed.
         **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
             list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
             of shape ``(batch_size, sequence_length, hidden_size)``:
@@ -370,9 +377,17 @@ class GPT2Model(GPT2PreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)
-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_shape[-1])
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
         if position_ids is not None:
@@ -384,8 +399,9 @@ class GPT2Model(GPT2PreTrainedModel):
         else:
             past_length = past[0][0].size(-2)
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         # Attention mask.
         if attention_mask is not None:
@@ -419,6 +435,7 @@ class GPT2Model(GPT2PreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer
-        inputs_embeds = self.wte(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         if token_type_ids is not None:
@@ -488,7 +505,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         **past**:
             list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
             that contains pre-computed hidden-states (key and values in the attention blocks).
-            Can be used (see `past` input) to speed up sequential decoding.
+            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
+            should not be passed as input ids as they have already been computed.
         **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
             list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
             of shape ``(batch_size, sequence_length, hidden_size)``:
@@ -520,14 +538,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                past=past,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
@@ -579,7 +598,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         **past**:
             list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
             that contains pre-computed hidden-states (key and values in the attention blocks).
-            Can be used (see `past` input) to speed up sequential decoding.
+            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
+            should not be passed as input ids as they have already been computed.
         **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
             list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
             of shape ``(batch_size, sequence_length, hidden_size)``:
@@ -623,14 +643,15 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 mc_token_ids=None, lm_labels=None, mc_labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                past=past,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
......
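The two behavioural changes in the GPT-2 hunks above are (1) `past` callers should now feed only the token ids that are not yet cached, and (2) `inputs_embeds` may replace `input_ids` entirely. A minimal sketch of both, assuming the `transformers` Python API as of this commit (the prompt text is arbitrary and the snippet is untested here):

    import torch
    from transformers import GPT2Tokenizer, GPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("The Manhattan Bridge is")])

    with torch.no_grad():
        # First pass over the full prompt; outputs are (lm_logits, presents, ...).
        lm_logits, past = model(input_ids)[:2]
        next_token = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)

        # Second pass: feed ONLY the newly generated token id plus `past`;
        # ids whose key/value states are already cached must not be repeated.
        lm_logits, past = model(next_token, past=past)[:2]

        # Alternatively, bypass the embedding lookup and pass `inputs_embeds`.
        inputs_embeds = model.transformer.wte(input_ids)  # (batch, seq_len, n_embd)
        outputs = model(inputs_embeds=inputs_embeds)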
@@ -322,6 +322,10 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
@@ -373,14 +377,22 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
         if position_ids is None:
-            # This was used when we had a single embedding matrice from position and token embeddings
-            # start = self.config.vocab_size + self.config.n_special
-            # end = start + input_ids.size(-1)
-            # position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
-            position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            # Code is different from when we had a single embedding matrice from position and token embeddings
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         # Attention mask.
         if attention_mask is not None:
@@ -413,10 +425,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_ids.size(-1))
-        position_ids = position_ids.view(-1, position_ids.size(-1))
-        inputs_embeds = self.tokens_embed(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.tokens_embed(input_ids)
         position_embeds = self.positions_embed(position_ids)
         if token_type_ids is not None:
@@ -495,13 +504,14 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
@@ -587,13 +597,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 mc_token_ids=None, lm_labels=None, mc_labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
......
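The OpenAI GPT forward now enforces the same contract as GPT-2: exactly one of `input_ids` or `inputs_embeds` must be given. A short sketch of the intended call patterns, assuming the library API at the time of this commit (untested here):

    import torch
    from transformers import OpenAIGPTTokenizer, OpenAIGPTModel

    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
    model = OpenAIGPTModel.from_pretrained("openai-gpt")

    input_ids = torch.tensor([tokenizer.encode("hello world")])
    inputs_embeds = model.tokens_embed(input_ids)   # same lookup the model does internally

    hidden = model(input_ids=input_ids)[0]          # ok
    hidden = model(inputs_embeds=inputs_embeds)[0]  # ok
    # model(input_ids, inputs_embeds=inputs_embeds) -> ValueError (both given)
    # model()                                       -> ValueError (neither given)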
@@ -35,6 +35,8 @@ ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-pytorch_model.bin",
     'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-pytorch_model.bin",
     'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-pytorch_model.bin",
+    'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-pytorch_model.bin",
+    'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-pytorch_model.bin",
 }
 class RobertaEmbeddings(BertEmbeddings):
@@ -48,16 +50,24 @@ class RobertaEmbeddings(BertEmbeddings):
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size,
                                                 padding_idx=self.padding_idx)
-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
-        seq_length = input_ids.size(1)
+    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+        seq_length = input_shape[1]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
         if position_ids is None:
             # Position numbers begin at padding_idx+1. Padding symbols are ignored.
             # cf. fairseq's `utils.make_positions`
-            position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).expand(input_shape)
         return super(RobertaEmbeddings, self).forward(input_ids,
                                                       token_type_ids=token_type_ids,
-                                                      position_ids=position_ids)
+                                                      position_ids=position_ids,
+                                                      inputs_embeds=inputs_embeds)
 ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in
@@ -126,6 +136,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
@@ -222,13 +236,14 @@ class RobertaForMaskedLM(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 masked_lm_labels=None):
         outputs = self.roberta(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                position_ids=position_ids,
-                               head_mask=head_mask)
+                               head_mask=head_mask,
+                               inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
         prediction_scores = self.lm_head(sequence_output)
@@ -309,13 +324,14 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
         self.roberta = RobertaModel(config)
         self.classifier = RobertaClassificationHead(config)
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 labels=None):
         outputs = self.roberta(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                position_ids=position_ids,
-                               head_mask=head_mask)
+                               head_mask=head_mask,
+                               inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
         logits = self.classifier(sequence_output)
@@ -372,6 +388,10 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
         **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
             Labels for computing the multiple choice classification loss.
             Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
@@ -415,8 +435,8 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
-                position_ids=None, head_mask=None):
+    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
+                position_ids=None, head_mask=None, inputs_embeds=None):
         num_choices = input_ids.shape[1]
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
@@ -487,14 +507,15 @@ class RobertaForTokenClassification(BertPreTrainedModel):
         self.init_weights()
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
         outputs = self.roberta(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                position_ids=position_ids,
-                               head_mask=head_mask)
+                               head_mask=head_mask,
+                               inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
......
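Besides threading `inputs_embeds` through the RoBERTa heads, the first hunk registers shortcut names for two new checkpoints. Assuming these are the RoBERTa-based GPT-2 output detector weights that the names suggest (a binary real-vs-generated classifier), they would be loaded like any other shortcut; a hedged sketch, untested here:

    import torch
    from transformers import RobertaTokenizer, RobertaForSequenceClassification

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    # New shortcut name added to ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP above.
    model = RobertaForSequenceClassification.from_pretrained("roberta-base-openai-detector")

    input_ids = torch.tensor([tokenizer.encode("Some text to score", add_special_tokens=True)])
    with torch.no_grad():
        logits = model(input_ids)[0]  # classification scores; label meaning is checkpoint-specific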
(The diff for one additional file is collapsed and not shown.)
@@ -109,6 +109,9 @@ class TFAutoModel(object):
             force_download: (`optional`) boolean, default False:
                 Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
+            resume_download: (`optional`) boolean, default False:
+                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
             proxies: (`optional`) dict, default None:
                 A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                 The proxies are used on each request.
@@ -237,6 +240,9 @@ class TFAutoModelWithLMHead(object):
             force_download: (`optional`) boolean, default False:
                 Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
+            resume_download: (`optional`) boolean, default False:
+                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
             proxies: (`optional`) dict, default None:
                 A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                 The proxies are used on each request.
@@ -360,6 +366,9 @@ class TFAutoModelForSequenceClassification(object):
             force_download: (`optional`) boolean, default False:
                 Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
+            resume_download: (`optional`) boolean, default False:
+                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
             proxies: (`optional`) dict, default None:
                 A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                 The proxies are used on each request.
@@ -472,6 +481,9 @@ class TFAutoModelForQuestionAnswering(object):
             force_download: (`optional`) boolean, default False:
                 Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
+            resume_download: (`optional`) boolean, default False:
+                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
             proxies: (`optional`) dict, default None:
                 A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                 The proxies are used on each request.
......
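The new `resume_download` flag documented above is a plain keyword argument to `from_pretrained`, alongside the existing `force_download` and `proxies` options in the same docstrings. A minimal sketch (the proxy values are copied from the docstring example):

    from transformers import TFAutoModel

    # Keep a partially downloaded weight file and resume it instead of restarting.
    model = TFAutoModel.from_pretrained(
        "bert-base-uncased",
        resume_download=True,
        proxies={"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"},
    )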
@@ -142,19 +142,25 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
     def _embedding(self, inputs, training=False):
         """Applies embedding based on inputs tensor."""
-        input_ids, position_ids, token_type_ids = inputs
-        seq_length = tf.shape(input_ids)[1]
+        input_ids, position_ids, token_type_ids, inputs_embeds = inputs
+        if input_ids is not None:
+            input_shape = tf.shape(input_ids)
+        else:
+            input_shape = tf.shape(inputs_embeds)[:-1]
+        seq_length = input_shape[1]
         if position_ids is None:
             position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
         if token_type_ids is None:
-            token_type_ids = tf.fill(tf.shape(input_ids), 0)
+            token_type_ids = tf.fill(input_shape, 0)
-        words_embeddings = tf.gather(self.word_embeddings, input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = tf.gather(self.word_embeddings, input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
-        embeddings = words_embeddings + position_embeddings + token_type_embeddings
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings, training=training)
         return embeddings
@@ -460,6 +466,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         self.encoder = TFBertEncoder(config, name='encoder')
         self.pooler = TFBertPooler(config, name='pooler')
+    def get_input_embeddings(self):
+        return self.embeddings
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -470,28 +479,39 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError
-    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
             token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
             position_ids = inputs[3] if len(inputs) > 3 else position_ids
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
+            assert len(inputs) <= 6, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             attention_mask = inputs.get('attention_mask', attention_mask)
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             position_ids = inputs.get('position_ids', position_ids)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 6, "Too many inputs."
         else:
             input_ids = inputs
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.shape
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.shape[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
         if attention_mask is None:
-            attention_mask = tf.fill(tf.shape(input_ids), 1)
+            attention_mask = tf.fill(input_shape, 1)
         if token_type_ids is None:
-            token_type_ids = tf.fill(tf.shape(input_ids), 0)
+            token_type_ids = tf.fill(input_shape, 0)
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
@@ -520,7 +540,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
             head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)
-        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids], training=training)
+        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
         encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
         sequence_output = encoder_outputs[0]
@@ -616,6 +636,10 @@ BERT_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
@@ -698,6 +722,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         self.nsp = TFBertNSPHead(config, name='nsp___cls')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
+    def get_output_embeddings(self):
+        return self.bert.embeddings
     def call(self, inputs, **kwargs):
         outputs = self.bert(inputs, **kwargs)
@@ -743,6 +770,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
+    def get_output_embeddings(self):
+        return self.bert.embeddings
     def call(self, inputs, **kwargs):
         outputs = self.bert(inputs, **kwargs)
@@ -888,33 +918,39 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
                                                 kernel_initializer=get_initializer(config.initializer_range),
                                                 name='classifier')
-    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
             token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
             position_ids = inputs[3] if len(inputs) > 3 else position_ids
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
+            assert len(inputs) <= 6, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             attention_mask = inputs.get('attention_mask', attention_mask)
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             position_ids = inputs.get('position_ids', position_ids)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 6, "Too many inputs."
         else:
             input_ids = inputs
-        num_choices = tf.shape(input_ids)[1]
-        seq_length = tf.shape(input_ids)[2]
+        if input_ids is not None:
+            num_choices = tf.shape(input_ids)[1]
+            seq_length = tf.shape(input_ids)[2]
+        else:
+            num_choices = tf.shape(inputs_embeds)[1]
+            seq_length = tf.shape(inputs_embeds)[2]
-        flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
         flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
         flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
-        flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
+        flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
         outputs = self.bert(flat_inputs, training=training)
......
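With the dict-input path extended above, a TF 2.0 BERT model can be driven entirely from pre-computed embeddings. A sketch of that path, assuming the TF API at the time of this commit (the `tf.gather` mirrors what `TFBertEmbeddings._embedding` does internally; untested here):

    import tensorflow as tf
    from transformers import BertTokenizer, TFBertModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TFBertModel.from_pretrained("bert-base-uncased")

    input_ids = tf.constant([tokenizer.encode("Hello, world", add_special_tokens=True)])
    inputs_embeds = tf.gather(model.bert.embeddings.word_embeddings, input_ids)

    # 'inputs_embeds' is picked up from the dict when 'input_ids' is absent.
    sequence_output, pooled_output = model({"inputs_embeds": inputs_embeds})[:2]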
@@ -192,6 +192,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
                   name='h_._{}'.format(i)) for i in range(config.n_layer)]
         self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")
+    def get_input_embeddings(self):
+        return self.w
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -201,7 +204,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError
-    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             past = inputs[1] if len(inputs) > 1 else past
@@ -209,7 +212,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
             position_ids = inputs[4] if len(inputs) > 4 else position_ids
             head_mask = inputs[5] if len(inputs) > 5 else head_mask
-            assert len(inputs) <= 6, "Too many inputs."
+            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             past = inputs.get('past', past)
@@ -217,12 +221,20 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             position_ids = inputs.get('position_ids', position_ids)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 6, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs
-        input_shape = shape_list(input_ids)
-        input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
         if past is None:
             past_length = 0
@@ -230,8 +242,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         else:
             past_length = shape_list(past[0][0])[-2]
         if position_ids is None:
-            position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
-            position_ids = tf.tile(position_ids, [shape_list(input_ids)[0], 1])
+            position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
+            position_ids = tf.tile(position_ids, [input_shape[0], 1])
         # Attention mask.
         if attention_mask is not None:
@@ -270,8 +282,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             token_type_embeds = 0
         position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
-        inputs_embeds = self.w(input_ids, mode='embedding')
-        # x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
+        if inputs_embeds is None:
+            inputs_embeds = self.w(input_ids, mode='embedding')
         seq_len = input_shape[-1]
         mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
@@ -374,6 +386,10 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 @add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
@@ -476,6 +492,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
         self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
+    def get_output_embeddings(self):
+        return self.lm_head.input_embeddings
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
......
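The CTRL hunks mainly add the embedding accessors (`get_input_embeddings` on the main layer, `get_output_embeddings` on the LM-head model). Since `TFCTRLLMHead` is constructed from `self.transformer.w`, both accessors are expected to resolve to the same shared layer; a hedged sketch (the `ctrl` checkpoint is very large, shown only for illustration; untested here):

    from transformers import TFCTRLLMHeadModel

    model = TFCTRLLMHeadModel.from_pretrained("ctrl")

    input_emb = model.transformer.get_input_embeddings()   # the shared `w` embedding layer
    output_emb = model.get_output_embeddings()             # lm_head.input_embeddings
    print(output_emb is input_emb)  # expected True: the LM head is tied to `w`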
...@@ -96,7 +96,7 @@ class TFEmbeddings(tf.keras.layers.Layer): ...@@ -96,7 +96,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range)) initializer=get_initializer(self.initializer_range))
super(TFEmbeddings, self).build(input_shape) super(TFEmbeddings, self).build(input_shape)
def call(self, inputs, mode="embedding", training=False): def call(self, inputs, inputs_embeds=None, mode="embedding", training=False):
"""Get token embeddings of inputs. """Get token embeddings of inputs.
Args: Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids) inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
...@@ -112,13 +112,13 @@ class TFEmbeddings(tf.keras.layers.Layer): ...@@ -112,13 +112,13 @@ class TFEmbeddings(tf.keras.layers.Layer):
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
""" """
if mode == "embedding": if mode == "embedding":
return self._embedding(inputs, training=training) return self._embedding(inputs, inputs_embeds=inputs_embeds, training=training)
elif mode == "linear": elif mode == "linear":
return self._linear(inputs) return self._linear(inputs)
else: else:
raise ValueError("mode {} is not valid.".format(mode)) raise ValueError("mode {} is not valid.".format(mode))
def _embedding(self, inputs, training=False): def _embedding(self, inputs, inputs_embeds=None, training=False):
""" """
Parameters Parameters
---------- ----------
...@@ -136,14 +136,19 @@ class TFEmbeddings(tf.keras.layers.Layer): ...@@ -136,14 +136,19 @@ class TFEmbeddings(tf.keras.layers.Layer):
else: else:
input_ids, position_ids = inputs input_ids, position_ids = inputs
if input_ids is not None:
seq_length = tf.shape(input_ids)[1] seq_length = tf.shape(input_ids)[1]
else:
seq_length = tf.shape(inputs_embeds)[1]
if position_ids is None: if position_ids is None:
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :] position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
word_embeddings = tf.gather(self.word_embeddings, input_ids) if inputs_embeds is None:
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim) position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
embeddings = word_embeddings + position_embeddings # (bs, max_seq_length, dim) embeddings = inputs_embeds + position_embeddings # (bs, max_seq_length, dim)
embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim) embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim)
embeddings = self.dropout(embeddings, training=training) # (bs, max_seq_length, dim) embeddings = self.dropout(embeddings, training=training) # (bs, max_seq_length, dim)
return embeddings return embeddings
...@@ -398,28 +403,42 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer): ...@@ -398,28 +403,42 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings
self.transformer = TFTransformer(config, name="transformer") # Encoder self.transformer = TFTransformer(config, name="transformer") # Encoder
def get_input_embeddings(self):
return self.embeddings
def _resize_token_embeddings(self, new_num_tokens): def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError raise NotImplementedError
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
raise NotImplementedError raise NotImplementedError
def call(self, inputs, attention_mask=None, head_mask=None, training=False): def call(self, inputs, attention_mask=None, head_mask=None, inputs_embeds=None, training=False):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
input_ids = inputs[0] input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
head_mask = inputs[2] if len(inputs) > 2 else head_mask head_mask = inputs[2] if len(inputs) > 2 else head_mask
assert len(inputs) <= 3, "Too many inputs." inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
assert len(inputs) <= 4, "Too many inputs."
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids') input_ids = inputs.get('input_ids')
attention_mask = inputs.get('attention_mask', attention_mask) attention_mask = inputs.get('attention_mask', attention_mask)
head_mask = inputs.get('head_mask', head_mask) head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 3, "Too many inputs." inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
assert len(inputs) <= 4, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None: if attention_mask is None:
attention_mask = tf.ones(shape_list(input_ids)) # (bs, seq_length) attention_mask = tf.ones(input_shape) # (bs, seq_length)
attention_mask = tf.cast(attention_mask, dtype=tf.float32) attention_mask = tf.cast(attention_mask, dtype=tf.float32)
# Prepare head mask if needed # Prepare head mask if needed
...@@ -432,7 +451,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer): ...@@ -432,7 +451,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
else: else:
head_mask = [None] * self.num_hidden_layers head_mask = [None] * self.num_hidden_layers
embedding_output = self.embeddings(input_ids) # (bs, seq_length, dim) embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
tfmr_output = self.transformer([embedding_output, attention_mask, head_mask], training=training) tfmr_output = self.transformer([embedding_output, attention_mask, head_mask], training=training)
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions) return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
...@@ -508,6 +527,10 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare DistilBERT encoder/transformer outputing raw hidden-states without any specific head on top.", @add_start_docstrings("The bare DistilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
...@@ -609,6 +632,9 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
        self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
        self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")

    def get_output_embeddings(self):
        return self.vocab_projector.input_embeddings

    def call(self, inputs, **kwargs):
        distilbert_output = self.distilbert(inputs, **kwargs)
...
...@@ -219,6 +219,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
                  name='h_._{}'.format(i)) for i in range(config.n_layer)]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')

    def get_input_embeddings(self):
        return self.wte

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

...@@ -228,7 +231,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
        """
        raise NotImplementedError
    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
...@@ -236,7 +239,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            assert len(inputs) <= 7, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            past = inputs.get('past', past)
...@@ -244,17 +248,28 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
            token_type_ids = inputs.get('token_type_ids', token_type_ids)
            position_ids = inputs.get('position_ids', position_ids)
            head_mask = inputs.get('head_mask', head_mask)
            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
            assert len(inputs) <= 7, "Too many inputs."
        else:
            input_ids = inputs
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = shape_list(past[0][0])[-2]

        if position_ids is None:
            position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]

        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
...@@ -286,10 +301,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids, mode='embedding')
        position_embeds = self.wpe(position_ids)

        if token_type_ids is not None:
...@@ -408,6 +422,10 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.", @add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.",
...@@ -486,6 +504,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
        super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name='transformer')

    def get_output_embeddings(self):
        return self.transformer.wte

    def call(self, inputs, **kwargs):
        transformer_outputs = self.transformer(inputs, **kwargs)
        hidden_states = transformer_outputs[0]
...@@ -556,7 +577,10 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
        self.transformer = TFGPT2MainLayer(config, name='transformer')
        self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')

    def get_output_embeddings(self):
        return self.transformer.wte

    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, training=False):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
...@@ -564,8 +588,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
            assert len(inputs) <= 8, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            past = inputs.get('past', past)
...@@ -573,21 +598,25 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
            token_type_ids = inputs.get('token_type_ids', token_type_ids)
            position_ids = inputs.get('position_ids', position_ids)
            head_mask = inputs.get('head_mask', head_mask)
            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
            mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
            assert len(inputs) <= 8, "Too many inputs."
        else:
            input_ids = inputs
        if input_ids is not None:
            input_shapes = shape_list(input_ids)
        else:
            input_shapes = shape_list(inputs_embeds)[:-1]

        seq_length = input_shapes[-1]

        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None

        flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]

        transformer_outputs = self.transformer(flat_inputs, training=training)
        hidden_states = transformer_outputs[0]
...
...@@ -48,13 +48,17 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
    def _embedding(self, inputs, training=False):
        """Applies embedding based on inputs tensor."""
        input_ids, position_ids, token_type_ids, inputs_embeds = inputs

        if input_ids is not None:
            seq_length = tf.shape(input_ids)[1]
        else:
            seq_length = tf.shape(inputs_embeds)[1]

        if position_ids is None:
            position_ids = tf.range(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=tf.int32)[tf.newaxis, :]

        return super(TFRobertaEmbeddings, self)._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)


class TFRobertaMainLayer(TFBertMainLayer):
...@@ -65,6 +69,9 @@ class TFRobertaMainLayer(TFBertMainLayer):
        super(TFRobertaMainLayer, self).__init__(config, **kwargs)
        self.embeddings = TFRobertaEmbeddings(config, name='embeddings')

    def get_input_embeddings(self):
        return self.embeddings


class TFRobertaPreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
...@@ -157,6 +164,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare RoBERTa Model transformer outputing raw hidden-states without any specific head on top.", @add_start_docstrings("The bare RoBERTa Model transformer outputing raw hidden-states without any specific head on top.",
...@@ -276,6 +287,9 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
        self.roberta = TFRobertaMainLayer(config, name="roberta")
        self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def call(self, inputs, **kwargs):
        outputs = self.roberta(inputs, **kwargs)
...