chenpangpang / transformers · Commit 84be482f (unverified)

AutoTokenizer supports mbart-large-en-ro (#5121)

Authored by Sam Shleifer on Jun 18, 2020; committed via GitHub on Jun 18, 2020.
Parent: 2db1e2f4
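In practical terms, the commit lets the Auto classes resolve mBART checkpoints end to end. A minimal sketch of the newly supported call (assuming the facebook/mbart-large-en-ro checkpoint is reachable):

    from transformers import AutoTokenizer

    # After this change, the "mbart" pattern resolves to MBartConfig, which in
    # turn maps to MBartTokenizer, so AutoTokenizer no longer falls back to the
    # plain BartTokenizer for mBART checkpoints.
    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    print(type(tokenizer).__name__)  # expected: MBartTokenizer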
Showing 4 changed files with 12 additions and 6 deletions.
src/transformers/configuration_auto.py  (+2 -1)
src/transformers/configuration_bart.py  (+4 -0)
src/transformers/tokenization_auto.py   (+3 -1)
tests/test_modeling_bart.py             (+3 -4)
src/transformers/configuration_auto.py

@@ -19,7 +19,7 @@ import logging
 from collections import OrderedDict

 from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
-from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
+from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, MBartConfig
 from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
 from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
 from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig

@@ -80,6 +80,7 @@ CONFIG_MAPPING = OrderedDict(
         ("camembert", CamembertConfig,),
         ("xlm-roberta", XLMRobertaConfig,),
         ("marian", MarianConfig,),
+        ("mbart", MBartConfig,),
         ("bart", BartConfig,),
         ("reformer", ReformerConfig,),
         ("longformer", LongformerConfig,),
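One ordering detail matters here: in this era of the library, AutoConfig picks the first CONFIG_MAPPING key that occurs as a substring of the model identifier, and "facebook/mbart-large-en-ro" contains both "mbart" and "bart", so the new "mbart" entry must sit above "bart". A rough sketch of that dispatch, paraphrased rather than copied from the source (the function name is illustrative):

    # Hypothetical paraphrase of the substring-based lookup AutoConfig used
    # at the time; not the library's exact code.
    def resolve_config_class(model_name, config_mapping):
        for pattern, config_class in config_mapping.items():
            if pattern in model_name:
                return config_class
        raise ValueError("Unrecognized model identifier: {}".format(model_name))

    # "mbart" is checked before "bart", so the mBART checkpoint gets
    # MBartConfig instead of falling through to BartConfig.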
src/transformers/configuration_bart.py

@@ -133,3 +133,7 @@ class BartConfig(PretrainedConfig):
         if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
             logger.info("This configuration is a mixture of MBART and BART settings")
             return False
+
+
+class MBartConfig(BartConfig):
+    model_type = "mbart"
src/transformers/tokenization_auto.py

@@ -30,6 +30,7 @@ from .configuration_auto import (
     FlaubertConfig,
     GPT2Config,
     LongformerConfig,
+    MBartConfig,
     OpenAIGPTConfig,
     ReformerConfig,
     RetriBertConfig,

@@ -43,7 +44,7 @@ from .configuration_auto import (
 from .configuration_marian import MarianConfig
 from .configuration_utils import PretrainedConfig
 from .tokenization_albert import AlbertTokenizer
-from .tokenization_bart import BartTokenizer
+from .tokenization_bart import BartTokenizer, MBartTokenizer
 from .tokenization_bert import BertTokenizer, BertTokenizerFast
 from .tokenization_bert_japanese import BertJapaneseTokenizer
 from .tokenization_camembert import CamembertTokenizer

@@ -75,6 +76,7 @@ TOKENIZER_MAPPING = OrderedDict(
         (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
         (AlbertConfig, (AlbertTokenizer, None)),
         (CamembertConfig, (CamembertTokenizer, None)),
+        (MBartConfig, (MBartTokenizer, None)),
         (XLMRobertaConfig, (XLMRobertaTokenizer, None)),
         (MarianConfig, (MarianTokenizer, None)),
         (BartConfig, (BartTokenizer, None)),
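TOKENIZER_MAPPING is matched differently from CONFIG_MAPPING: AutoTokenizer walks the OrderedDict and takes the first entry whose config class the loaded config is an instance of. Because MBartConfig subclasses BartConfig, the new (MBartConfig, (MBartTokenizer, None)) entry has to appear above the BartConfig entry, or an mBART config would match BartTokenizer first. A paraphrased sketch of that dispatch (illustrative names, not the library's exact code):

    # Hypothetical paraphrase of AutoTokenizer's isinstance-based lookup.
    def resolve_tokenizer_class(config, tokenizer_mapping, use_fast=False):
        for config_class, (slow_cls, fast_cls) in tokenizer_mapping.items():
            if isinstance(config, config_class):
                if use_fast and fast_cls is not None:
                    return fast_cls
                return slow_cls
        raise ValueError("No tokenizer registered for {}".format(type(config).__name__))

    # An MBartConfig instance is also a BartConfig instance, so the mBART
    # entry must be checked first to return MBartTokenizer, not BartTokenizer.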
tests/test_modeling_bart.py

@@ -31,6 +31,7 @@ if is_torch_available():
     from transformers import (
         AutoModel,
         AutoModelForSequenceClassification,
+        AutoModelForSeq2SeqLM,
         AutoTokenizer,
         BartModel,
         BartForConditionalGeneration,

@@ -38,7 +39,6 @@ if is_torch_available():
         BartForQuestionAnswering,
         BartConfig,
         BartTokenizer,
-        MBartTokenizer,
         BatchEncoding,
         pipeline,
     )

@@ -218,15 +218,14 @@ class MBartIntegrationTests(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         checkpoint_name = "facebook/mbart-large-en-ro"
-        cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
+        cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
         cls.pad_token_id = 1
         return cls

     @cached_property
     def model(self):
         """Only load the model if needed."""
-        model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
+        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
         if "cuda" in torch_device:
             model = model.half()
         return model
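The updated test now exercises the same path an end user would take: resolve both tokenizer and model through the Auto classes. A minimal end-to-end sketch under those assumptions (the example sentence and generation defaults are mine, not from the test, and the checkpoint config is assumed to supply sensible generation settings):

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro")

    # Encode an English sentence and translate it to Romanian. Note that the
    # MBartTokenizer of this era also offered a dedicated batch-preparation
    # helper for setting language codes; plain encode() is a minimal sketch.
    input_ids = tokenizer.encode(
        "UN Chief Says There Is No Military Solution in Syria", return_tensors="pt"
    )
    with torch.no_grad():
        generated = model.generate(input_ids)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))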