Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
778a263f
Commit
778a263f
authored
Aug 27, 2019
by
LysandreJik
Browse files
GilBert added to AutoModels
parent
74d78bee
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
2 deletions
+7
-2
pytorch_transformers/modeling_auto.py
pytorch_transformers/modeling_auto.py
+7
-2
No files found.
pytorch_transformers/modeling_auto.py
View file @
778a263f
...
@@ -30,6 +30,7 @@ from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel
...
@@ -30,6 +30,7 @@ from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel
from
.modeling_xlnet
import
XLNetConfig
,
XLNetModel
from
.modeling_xlnet
import
XLNetConfig
,
XLNetModel
from
.modeling_xlm
import
XLMConfig
,
XLMModel
from
.modeling_xlm
import
XLMConfig
,
XLMModel
from
.modeling_roberta
import
RobertaConfig
,
RobertaModel
from
.modeling_roberta
import
RobertaConfig
,
RobertaModel
from
.modeling_dilbert
import
DilBertConfig
,
DilBertModel
from
.modeling_utils
import
PreTrainedModel
,
SequenceSummary
from
.modeling_utils
import
PreTrainedModel
,
SequenceSummary
...
@@ -110,7 +111,9 @@ class AutoConfig(object):
...
@@ -110,7 +111,9 @@ class AutoConfig(object):
assert unused_kwargs == {'foo': False}
assert unused_kwargs == {'foo': False}
"""
"""
if
'roberta'
in
pretrained_model_name_or_path
:
if
'dilbert'
in
pretrained_model_name_or_path
:
return
DilBertconfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
return
RobertaConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
return
RobertaConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
elif
'bert'
in
pretrained_model_name_or_path
:
elif
'bert'
in
pretrained_model_name_or_path
:
return
BertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
return
BertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
...
@@ -225,7 +228,9 @@ class AutoModel(object):
...
@@ -225,7 +228,9 @@ class AutoModel(object):
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
"""
if
'roberta'
in
pretrained_model_name_or_path
:
if
'dilbert'
in
pretrained_model_name_or_path
:
return
DilBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
return
RobertaModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
RobertaModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'bert'
in
pretrained_model_name_or_path
:
elif
'bert'
in
pretrained_model_name_or_path
:
return
BertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
BertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment