Unverified Commit 367235ee authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Bart can make decoder_input_ids from labels (#6758)

parent b9772897
......@@ -58,8 +58,8 @@ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/bart-large-cnn",
"facebook/bart-large-xsum",
"facebook/mbart-large-en-ro",
# See all BART models at https://huggingface.co/models?filter=bart
]
# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart
BART_START_DOCSTRING = r"""
......@@ -1045,6 +1045,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
if labels is not None:
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
outputs = self.model(
input_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment