Commit c536c2a4 authored by LysandreJik, committed by Lysandre Debut

ALBERT Input Embeds

parent f873b55e
@@ -433,6 +433,12 @@ class AlbertModel(AlbertPreTrainedModel):
         self.init_weights()
 
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
     def _resize_token_embeddings(self, new_num_tokens):
         old_embeddings = self.embeddings.word_embeddings
         new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
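The two new accessors give callers a uniform way to read or replace ALBERT's word-embedding table. A minimal usage sketch, not part of the diff (the checkpoint name and the decision to clone the matrix are illustrative assumptions):

import torch
from transformers import AlbertModel

model = AlbertModel.from_pretrained("albert-base-v2")   # assumed checkpoint, for illustration only

# Read the current word-embedding module (an nn.Embedding of shape [vocab_size, embedding_size]).
old_embeddings = model.get_input_embeddings()
print(old_embeddings.weight.shape)

# Swap in a trainable copy of the same matrix, e.g. to fine-tune the embeddings separately.
new_embeddings = torch.nn.Embedding.from_pretrained(old_embeddings.weight.clone(), freeze=False)
model.set_input_embeddings(new_embeddings)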
@@ -457,12 +463,24 @@ class AlbertModel(AlbertPreTrainedModel):
             inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
             self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+            attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
@@ -477,7 +495,8 @@ class AlbertModel(AlbertPreTrainedModel):
         else:
             head_mask = [None] * self.config.num_hidden_layers
 
-        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
+        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                                           inputs_embeds=inputs_embeds)
         encoder_outputs = self.encoder(embedding_output,
                                        extended_attention_mask,
                                        head_mask=head_mask)
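With the reworked forward, a caller can skip the internal lookup and hand the embedding vectors in directly; input_ids and inputs_embeds are mutually exclusive. A rough usage sketch, not part of the diff (checkpoint and example sentence are assumptions; the shapes follow the embedding layer shown above):

import torch
from transformers import AlbertModel, AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")   # assumed checkpoint
model = AlbertModel.from_pretrained("albert-base-v2")

input_ids = torch.tensor([tokenizer.encode("Hello, ALBERT")])

# Look the vectors up ourselves; shape is (batch, seq_len, embedding_size).
inputs_embeds = model.get_input_embeddings()(input_ids)

# Pass exactly one of input_ids / inputs_embeds; passing both raises the ValueError above.
outputs = model(inputs_embeds=inputs_embeds)
sequence_output, pooled_output = outputs[:2]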
@@ -549,9 +568,19 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
         self._tie_or_clone_weights(self.predictions.decoder,
                                    self.albert.embeddings.word_embeddings)
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
-                masked_lm_labels=None):
-        outputs = self.albert(input_ids, attention_mask, token_type_ids, position_ids, head_mask)
+    def get_output_embeddings(self):
+        return self.predictions.decoder
+
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                masked_lm_labels=None, inputs_embeds=None):
+        outputs = self.albert(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds
+        )
         sequence_outputs = outputs[0]
 
         prediction_scores = self.predictions(sequence_outputs)
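The masked-LM head follows the same pattern: its forward now passes inputs_embeds through to the base model, and get_output_embeddings exposes the decoder that is tied to the input word embeddings. A hedged sketch, not part of the diff (checkpoint, sentence, and the choice to reuse the inputs as labels are illustrative only):

import torch
from transformers import AlbertForMaskedLM, AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")   # assumed checkpoint
model = AlbertForMaskedLM.from_pretrained("albert-base-v2")

input_ids = torch.tensor([tokenizer.encode("Paris is the capital of France")])
inputs_embeds = model.albert.get_input_embeddings()(input_ids)

# masked_lm_labels has the same shape as input_ids; here the inputs double as labels.
outputs = model(inputs_embeds=inputs_embeds, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2]

decoder = model.get_output_embeddings()   # the tied decoder exposed by this commit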
@@ -609,14 +638,17 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
-        outputs = self.albert(input_ids,
-                              attention_mask=attention_mask,
-                              token_type_ids=token_type_ids,
-                              position_ids=position_ids,
-                              head_mask=head_mask)
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
+        outputs = self.albert(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds
+        )
 
         pooled_output = outputs[1]
@@ -692,14 +724,17 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
-                start_positions=None, end_positions=None):
-        outputs = self.albert(input_ids,
-                              attention_mask=attention_mask,
-                              token_type_ids=token_type_ids,
-                              position_ids=position_ids,
-                              head_mask=head_mask)
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                inputs_embeds=None, start_positions=None, end_positions=None):
+        outputs = self.albert(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds
+        )
 
         sequence_output = outputs[0]
...
@@ -107,19 +107,25 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
     def _embedding(self, inputs, training=False):
         """Applies embedding based on inputs tensor."""
-        input_ids, position_ids, token_type_ids = inputs
+        input_ids, position_ids, token_type_ids, inputs_embeds = inputs
 
-        seq_length = tf.shape(input_ids)[1]
+        if input_ids is not None:
+            input_shape = tf.shape(input_ids)
+        else:
+            input_shape = tf.shape(inputs_embeds)[:-1]
+
+        seq_length = input_shape[1]
         if position_ids is None:
             position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
         if token_type_ids is None:
-            token_type_ids = tf.fill(tf.shape(input_ids), 0)
+            token_type_ids = tf.fill(input_shape, 0)
 
-        words_embeddings = tf.gather(self.word_embeddings, input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = tf.gather(self.word_embeddings, input_ids)
+
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
-        embeddings = words_embeddings + position_embeddings + token_type_embeddings
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings, training=training)
         return embeddings
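On the TensorFlow side, _embedding now takes a four-element list and only performs the table lookup when inputs_embeds is absent. A toy, standalone illustration of the two equivalent branches, not part of the diff (sizes are made up and the layer itself is not used):

import tensorflow as tf

vocab_size, embedding_size = 6, 4                        # toy dimensions
word_embeddings = tf.random.uniform((vocab_size, embedding_size))
input_ids = tf.constant([[1, 3, 5]])

# Branch taken when inputs_embeds is None: gather one row per id.
from_ids = tf.gather(word_embeddings, input_ids)

# Branch taken when the caller supplies the vectors: the in-layer gather is skipped.
inputs_embeds = tf.gather(word_embeddings, input_ids)    # caller-side lookup, passed in as-is

tf.debugging.assert_near(from_ids, inputs_embeds)        # both paths feed identical tensors onward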
@@ -603,6 +609,9 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
         self.pooler = tf.keras.layers.Dense(config.hidden_size, kernel_initializer=get_initializer(
             config.initializer_range), activation='tanh', name='pooler')
 
+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -613,28 +622,39 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
         """
         raise NotImplementedError
 
-    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
             token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
             position_ids = inputs[3] if len(inputs) > 3 else position_ids
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
+            assert len(inputs) <= 6, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             attention_mask = inputs.get('attention_mask', attention_mask)
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             position_ids = inputs.get('position_ids', position_ids)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 6, "Too many inputs."
         else:
             input_ids = inputs
 
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.shape
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.shape[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         if attention_mask is None:
-            attention_mask = tf.fill(tf.shape(input_ids), 1)
+            attention_mask = tf.fill(input_shape, 1)
         if token_type_ids is None:
-            token_type_ids = tf.fill(tf.shape(input_ids), 0)
+            token_type_ids = tf.fill(input_shape, 0)
 
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
@@ -664,7 +684,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
         # head_mask = tf.constant([0] * self.num_hidden_layers)
 
         embedding_output = self.embeddings(
-            [input_ids, position_ids, token_type_ids], training=training)
+            [input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
         encoder_outputs = self.encoder(
             [embedding_output, extended_attention_mask, head_mask], training=training)
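TFAlbertModel.call accepts the new argument as a keyword, as the sixth positional entry of a list input, or under the 'inputs_embeds' key of a dict input. A rough usage sketch with the dict form, not part of the diff (checkpoint, sentence, and the attribute path to the embedding table are assumptions based on the layer shown above):

import tensorflow as tf
from transformers import TFAlbertModel, AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")    # assumed checkpoint
model = TFAlbertModel.from_pretrained("albert-base-v2")

input_ids = tf.constant([tokenizer.encode("Hello, ALBERT")])

# Reuse the model's own table so the vectors match what the lookup would have produced.
embedding_layer = model.get_input_embeddings()                   # TFAlbertEmbeddings
inputs_embeds = tf.gather(embedding_layer.word_embeddings, input_ids)

# No input_ids in the dict, so the inputs_embeds path is taken.
outputs = model({"inputs_embeds": inputs_embeds})
sequence_output, pooled_output = outputs[:2]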
@@ -712,6 +732,9 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
         self.predictions = TFAlbertMLMHead(
             config, self.albert.embeddings, name='predictions')
 
+    def get_output_embeddings(self):
+        return self.albert.embeddings
+
     def call(self, inputs, **kwargs):
         outputs = self.albert(inputs, **kwargs)
...