"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "df08f346ce498e1815419ae850887df4c6c8dd9a"
Commit eec5ec80 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

[BART] to each its own config + make BART compatible w/ Pipelines

cc @sshleifer
parent 6b1558ba
...@@ -22,10 +22,9 @@ from .configuration_utils import PretrainedConfig ...@@ -22,10 +22,9 @@ from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json"
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bart-large": _bart_large_url, "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
"bart-large-mnli": _bart_large_url, # fine as same "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json", "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
} }
......
...@@ -28,6 +28,7 @@ from typing import Dict, List, Optional, Tuple, Union ...@@ -28,6 +28,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
from .configuration_bart import BartConfig
from .configuration_distilbert import DistilBertConfig from .configuration_distilbert import DistilBertConfig
from .configuration_roberta import RobertaConfig from .configuration_roberta import RobertaConfig
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
...@@ -427,7 +428,7 @@ class Pipeline(_ScikitCompat): ...@@ -427,7 +428,7 @@ class Pipeline(_ScikitCompat):
""" """
args = ["input_ids", "attention_mask"] args = ["input_ids", "attention_mask"]
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig)): if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig)):
args += ["token_type_ids"] args += ["token_type_ids"]
# PR #1548 (CLI) There is an issue with attention_mask # PR #1548 (CLI) There is an issue with attention_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment