Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
eec5ec80
Commit
eec5ec80
authored
Mar 02, 2020
by
Julien Chaumond
Browse files
[BART] to each its own config + make BART compatible w/ Pipelines
cc @sshleifer
parent
6b1558ba
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
4 deletions
+4
-4
src/transformers/configuration_bart.py
src/transformers/configuration_bart.py
+2
-3
src/transformers/pipelines.py
src/transformers/pipelines.py
+2
-1
No files found.
src/transformers/configuration_bart.py
View file @
eec5ec80
...
@@ -22,10 +22,9 @@ from .configuration_utils import PretrainedConfig
...
@@ -22,10 +22,9 @@ from .configuration_utils import PretrainedConfig
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
_bart_large_url
=
"https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json"
BART_PRETRAINED_CONFIG_ARCHIVE_MAP
=
{
BART_PRETRAINED_CONFIG_ARCHIVE_MAP
=
{
"bart-large"
:
_
bart
_
large
_url
,
"bart-large"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/facebook/
bart
-
large
/config.json"
,
"bart-large-mnli"
:
_
bart
_
large
_url
,
# fine as same
"bart-large-mnli"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/facebook/
bart
-
large
-mnli/config.json"
,
"bart-large-cnn"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json"
,
"bart-large-cnn"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json"
,
}
}
...
...
src/transformers/pipelines.py
View file @
eec5ec80
...
@@ -28,6 +28,7 @@ from typing import Dict, List, Optional, Tuple, Union
...
@@ -28,6 +28,7 @@ from typing import Dict, List, Optional, Tuple, Union
import
numpy
as
np
import
numpy
as
np
from
.configuration_auto
import
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
AutoConfig
from
.configuration_auto
import
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
AutoConfig
from
.configuration_bart
import
BartConfig
from
.configuration_distilbert
import
DistilBertConfig
from
.configuration_distilbert
import
DistilBertConfig
from
.configuration_roberta
import
RobertaConfig
from
.configuration_roberta
import
RobertaConfig
from
.configuration_utils
import
PretrainedConfig
from
.configuration_utils
import
PretrainedConfig
...
@@ -427,7 +428,7 @@ class Pipeline(_ScikitCompat):
...
@@ -427,7 +428,7 @@ class Pipeline(_ScikitCompat):
"""
"""
args
=
[
"input_ids"
,
"attention_mask"
]
args
=
[
"input_ids"
,
"attention_mask"
]
if
not
isinstance
(
self
.
model
.
config
,
(
DistilBertConfig
,
XLMConfig
,
RobertaConfig
)):
if
not
isinstance
(
self
.
model
.
config
,
(
DistilBertConfig
,
XLMConfig
,
RobertaConfig
,
BartConfig
)):
args
+=
[
"token_type_ids"
]
args
+=
[
"token_type_ids"
]
# PR #1548 (CLI) There is an issue with attention_mask
# PR #1548 (CLI) There is an issue with attention_mask
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment