Unverified commit 4d10ffd5, authored by Younes Belkada and committed by GitHub
Browse files

[`FSMT`] Make it compatible with `xxxForConditionalGeneration` models (#20825)



* add `get_encoder` and `get_decoder`

* add additional kwargs support

* fix condition

* add better checks

* better checks

* fix embed positions

* better test to consider padding

* fix debug statement

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add arguments on docstring
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 2222740f
......@@ -272,6 +272,18 @@ FSMT_INPUTS_DOCSTRING = r"""
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
input (see `past_key_values`). This is useful if you want more control over how to convert
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
If neither `decoder_input_ids` nor `decoder_inputs_embeds` is available to the decoder, a `ValueError` is
raised.
use_cache (`bool`, *optional*, defaults to `True`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
......@@ -470,6 +482,7 @@ class FSMTEncoder(nn.Module):
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: torch.Tensor = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
......@@ -480,6 +493,8 @@ class FSMTEncoder(nn.Module):
input_ids (`torch.LongTensor`): tokens in the source language of shape
*(batch, src_len)*
attention_mask (`torch.LongTensor`): indicating which indices are padding tokens
inputs_embeds (`torch.FloatTensor`):
embedding vectors of shape *(batch, src_len, embed_dim)*
head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
......@@ -499,8 +514,24 @@ class FSMTEncoder(nn.Module):
if attention_mask is not None:
attention_mask = invert_mask(attention_mask)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_ids)
elif inputs_embeds is not None:
inputs_embeds = inputs_embeds * self.embed_scale
# We assume zeros hidden states correspond to padding tokens
# and create `position_ids` where inputs_embeds[:, :, 0] == 0
position_ids = inputs_embeds[:, :, 0].masked_fill(
inputs_embeds[:, :, 0].eq(0), self.embed_positions.padding_idx
)
embed_pos = self.embed_positions(position_ids)
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
x = inputs_embeds + embed_pos
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
......@@ -675,6 +706,7 @@ class FSMTDecoder(nn.Module):
decoder_padding_mask: torch.Tensor,
decoder_causal_mask: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: bool = False,
......@@ -717,15 +749,26 @@ class FSMTDecoder(nn.Module):
if encoder_padding_mask is not None:
encoder_padding_mask = invert_mask(encoder_padding_mask)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
# embed positions
positions = self.embed_positions(input_ids) # , use_cache=use_cache)
positions = self.embed_positions(input_ids)
if use_cache:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
# assert input_ids.ne(self.padding_idx).any()
x = self.embed_tokens(input_ids) * self.embed_scale
elif inputs_embeds is not None:
# We assume zeros hidden states correspond to padding tokens
# and create `position_ids` where inputs_embeds[:, :, 0] == 0
position_ids = inputs_embeds[:, :, 0].masked_fill(
inputs_embeds[:, :, 0].eq(0), self.embed_positions.padding_idx
)
positions = self.embed_positions(position_ids)
x = inputs_embeds * self.embed_scale
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
x += positions
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
......@@ -1007,6 +1050,12 @@ class FSMTModel(PretrainedFSMTModel):
# Initialize weights and apply final processing
self.post_init()
def get_encoder(self):
    """Return the encoder held in `self.encoder`."""
    return self.encoder
def get_decoder(self):
    """Return the decoder held in `self.decoder`."""
    return self.decoder
@add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1028,6 +1077,8 @@ class FSMTModel(PretrainedFSMTModel):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
if decoder_input_ids is None:
......@@ -1041,7 +1092,7 @@ class FSMTModel(PretrainedFSMTModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# make masks if user doesn't supply
if not use_cache:
if not use_cache and input_ids is not None:
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_fsmt_decoder_inputs(
self.config,
input_ids,
......@@ -1052,12 +1103,14 @@ class FSMTModel(PretrainedFSMTModel):
else:
decoder_padding_mask, causal_mask = None, None
assert decoder_input_ids is not None
if decoder_input_ids is None and decoder_inputs_embeds is None:
raise ValueError("Make sure that `decoder_input_ids` or `decoder_inputs_embeds` are passed.")
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......@@ -1078,6 +1131,7 @@ class FSMTModel(PretrainedFSMTModel):
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
inputs_embeds=decoder_inputs_embeds,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
......@@ -1148,6 +1202,8 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
decoder_inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
......@@ -1170,8 +1226,10 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
outputs = self.model(
input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_inputs_embeds=decoder_inputs_embeds,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
......@@ -1248,6 +1306,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
def get_encoder(self):
    """Return the encoder of the wrapped model (`self.model.encoder`)."""
    return self.model.encoder
def get_decoder(self):
    """Return the decoder of the wrapped model (`self.model.decoder`)."""
    return self.model.decoder
def get_output_embeddings(self):
    """Return the wrapped decoder's `embed_tokens` as the output embeddings."""
    return self.model.decoder.embed_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment