Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a9f1fc6c
Unverified
Commit
a9f1fc6c
authored
Jun 15, 2020
by
Sam Shleifer
Committed by
GitHub
Jun 15, 2020
Browse files
Add bart-base (#5014)
parent
7b685f52
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
0 deletions
+20
-0
docs/source/pretrained_models.rst
docs/source/pretrained_models.rst
+2
-0
src/transformers/tokenization_bart.py
src/transformers/tokenization_bart.py
+1
-0
tests/test_modeling_bart.py
tests/test_modeling_bart.py
+17
-0
No files found.
docs/source/pretrained_models.rst
View file @
a9f1fc6c
...
...
@@ -278,6 +278,8 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| Bart | ``facebook/bart-large`` | | 24-layer, 1024-hidden, 16-heads, 406M parameters |
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``facebook/bart-base`` | | 12-layer, 768-hidden, 12-heads, 139M parameters |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``facebook/bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
| | | | bart-large base architecture with a classification head, finetuned on MNLI |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
...
...
src/transformers/tokenization_bart.py
View file @
a9f1fc6c
...
...
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
vocab_url
=
"https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url
=
"https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
_all_bart_models
=
[
"facebook/bart-base"
,
"facebook/bart-large"
,
"facebook/bart-large-mnli"
,
"facebook/bart-large-cnn"
,
...
...
tests/test_modeling_bart.py
View file @
a9f1fc6c
...
...
@@ -40,6 +40,7 @@ if is_torch_available():
BartTokenizer
,
MBartTokenizer
,
BatchEncoding
,
pipeline
,
)
from
transformers.modeling_bart
import
(
BART_PRETRAINED_MODEL_ARCHIVE_LIST
,
...
...
@@ -565,6 +566,22 @@ class BartModelIntegrationTests(unittest.TestCase):
)
self
.
assertTrue
(
torch
.
allclose
(
output
[:,
:
3
,
:
3
],
expected_slice
,
atol
=
TOLERANCE
))
@slow
def test_bart_base_mask_filling(self):
    """Check the top fill-mask predictions of facebook/bart-base on a fixed prompt."""
    fill_mask = pipeline(task="fill-mask", model="facebook/bart-base")
    outputs = fill_mask([" I went to the <mask>."])
    # Collect just the predicted token strings, in ranked order.
    predicted_tokens = []
    for prediction in outputs:
        predicted_tokens.append(prediction["token_str"])
    # NOTE: the "Ġ" prefix is the byte-level BPE marker for a leading space.
    self.assertListEqual(
        predicted_tokens,
        ["Ġbathroom", "Ġrestroom", "Ġhospital", "Ġkitchen", "Ġcar"],
    )
@slow
def test_bart_large_mask_filling(self):
    """Check the top fill-mask predictions of facebook/bart-large on a fixed prompt."""
    fill_mask = pipeline(task="fill-mask", model="facebook/bart-large")
    outputs = fill_mask([" I went to the <mask>."])
    # Collect just the predicted token strings, in ranked order.
    predicted_tokens = []
    for prediction in outputs:
        predicted_tokens.append(prediction["token_str"])
    # NOTE: the "Ġ" prefix is the byte-level BPE marker for a leading space.
    self.assertListEqual(
        predicted_tokens,
        ["Ġbathroom", "Ġgym", "Ġwrong", "Ġmovies", "Ġhospital"],
    )
@
slow
def
test_mnli_inference
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment