Unverified commit 707f7eb1, authored by Silviu Oprea, committed by GitHub

Bart: check if decoder_inputs_embeds is set (#13800)



In BartForConditionalGeneration.forward, if labels are provided,
decoder_input_ids are set to the labels shifted to the right.
This is problematic: if decoder_inputs_embeds is also set,
the call to self.model, which eventually reaches BartDecoder.forward,
will raise an error.
The fix is quite simple, similar to what is already there in
BartModel.forward. Namely, we should not
compute decoder_input_ids if decoder_inputs_embeds is provided.
Co-authored-by: Silviu Vlad Oprea <silviuvo@amazon.co.uk>
parent 42137280
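
For readers unfamiliar with the phrase "labels shifted to the right" in the message above, here is a minimal pure-Python sketch of the idea. It is not the library's shift_tokens_right, which operates on tensors; the function name shift_right_sketch, the list-of-lists input, and the IGNORE_INDEX constant are assumptions made for illustration. The replacement of -100 by the pad token reflects my understanding of the library's behavior and should be checked against its source.

```python
from typing import List

# Assumption: -100 is the label value that the loss ignores.
IGNORE_INDEX = -100


def shift_right_sketch(
    labels: List[List[int]], pad_token_id: int, decoder_start_token_id: int
) -> List[List[int]]:
    """Illustrative stand-in: build decoder inputs from labels.

    Each row is shifted one position to the right, the start token is
    placed in front, and the last label is dropped so the length is kept.
    Ignored label positions are replaced by the pad token.
    """
    shifted = []
    for row in labels:
        new_row = [decoder_start_token_id] + list(row[:-1])
        shifted.append(
            [pad_token_id if tok == IGNORE_INDEX else tok for tok in new_row]
        )
    return shifted


# Hypothetical token ids, chosen only for the example.
example = shift_right_sketch([[5, 6, 7, -100]], pad_token_id=1, decoder_start_token_id=2)
```

With the hypothetical ids above, the decoder sees the start token first and then each label one step late, which is what teacher forcing requires.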
@@ -1291,7 +1291,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         if labels is not None:
-            if decoder_input_ids is None:
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
                 )
@@ -2501,7 +2501,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         if labels is not None:
-            if decoder_input_ids is None:
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
                 )
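
The effect of the one-line guard change can be sketched without the library. Everything below is a hypothetical stand-in: toy_decoder only imitates the "exactly one of ids or embeds" check that the commit message attributes to BartDecoder.forward, toy_forward imitates the branch touched by the diff, and the patched flag exists only to compare the old and new conditions side by side.

```python
from typing import List, Optional


def toy_shift_right(labels: List[int], start_id: int) -> List[int]:
    """Hypothetical stand-in for shift_tokens_right on a single sequence."""
    return [start_id] + list(labels[:-1])


def toy_decoder(
    decoder_input_ids: Optional[List[int]],
    decoder_inputs_embeds: Optional[List[List[float]]],
) -> str:
    """Stand-in for the decoder's input check: exactly one form is allowed."""
    if decoder_input_ids is not None and decoder_inputs_embeds is not None:
        raise ValueError("cannot specify both decoder_input_ids and decoder_inputs_embeds")
    if decoder_input_ids is None and decoder_inputs_embeds is None:
        raise ValueError("must specify either decoder_input_ids or decoder_inputs_embeds")
    return "ids" if decoder_input_ids is not None else "embeds"


def toy_forward(
    labels: Optional[List[int]],
    decoder_input_ids: Optional[List[int]] = None,
    decoder_inputs_embeds: Optional[List[List[float]]] = None,
    patched: bool = True,
) -> str:
    """Imitates the branch changed by the diff; `patched` selects old vs new guard."""
    if labels is not None:
        if patched:
            derive_ids = decoder_input_ids is None and decoder_inputs_embeds is None
        else:
            derive_ids = decoder_input_ids is None
        if derive_ids:
            decoder_input_ids = toy_shift_right(labels, start_id=2)
    return toy_decoder(decoder_input_ids, decoder_inputs_embeds)


# Made-up labels and embeddings, used only to exercise both guards.
labels = [5, 6, 7]
embeds = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

new_result = toy_forward(labels, decoder_inputs_embeds=embeds, patched=True)
try:
    toy_forward(labels, decoder_inputs_embeds=embeds, patched=False)
    old_raised = False
except ValueError:
    old_raised = True
```

With the old guard, passing labels together with decoder_inputs_embeds produces both input forms and the stand-in decoder rejects the call. With the new guard, the embeddings are passed through untouched and ids are derived from labels only when neither form was supplied.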