"web/git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "6db777b3489798740c0ffdc6f503fe4279f2c435"
Unverified Commit 367235ee authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Bart can make decoder_input_ids from labels (#6758)

parent b9772897
...@@ -58,8 +58,8 @@ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -58,8 +58,8 @@ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/bart-large-cnn", "facebook/bart-large-cnn",
"facebook/bart-large-xsum", "facebook/bart-large-xsum",
"facebook/mbart-large-en-ro", "facebook/mbart-large-en-ro",
# See all BART models at https://huggingface.co/models?filter=bart
] ]
# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart
BART_START_DOCSTRING = r""" BART_START_DOCSTRING = r"""
...@@ -1045,6 +1045,8 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1045,6 +1045,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
if labels is not None: if labels is not None:
use_cache = False use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
outputs = self.model( outputs = self.model(
input_ids, input_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment