Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a701c9b3
Commit
a701c9b3
authored
Oct 11, 2019
by
Lysandre
Browse files
CTRL to tf automodels
parent
d844db40
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
3 deletions
+11
-3
transformers/modeling_tf_auto.py
transformers/modeling_tf_auto.py
+11
-3
No files found.
transformers/modeling_tf_auto.py
View file @
a701c9b3
...
@@ -26,6 +26,7 @@ from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSeque
...
@@ -26,6 +26,7 @@ from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSeque
from
.modeling_tf_xlm
import
TFXLMModel
,
TFXLMWithLMHeadModel
,
TFXLMForSequenceClassification
,
TFXLMForQuestionAnsweringSimple
from
.modeling_tf_xlm
import
TFXLMModel
,
TFXLMWithLMHeadModel
,
TFXLMForSequenceClassification
,
TFXLMForQuestionAnsweringSimple
from
.modeling_tf_roberta
import
TFRobertaModel
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
from
.modeling_tf_roberta
import
TFRobertaModel
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
from
.modeling_tf_distilbert
import
TFDistilBertModel
,
TFDistilBertForQuestionAnswering
,
TFDistilBertForMaskedLM
,
TFDistilBertForSequenceClassification
from
.modeling_tf_distilbert
import
TFDistilBertModel
,
TFDistilBertForQuestionAnswering
,
TFDistilBertForMaskedLM
,
TFDistilBertForSequenceClassification
from
.modeling_tf_ctrl
import
TFCTRLModel
,
TFCTRLLMHeadModel
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
...
@@ -52,6 +53,7 @@ class TFAutoModel(object):
...
@@ -52,6 +53,7 @@ class TFAutoModel(object):
- contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
- contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
- contains `xlnet`: TFXLNetModel (XLNet model)
- contains `xlnet`: TFXLNetModel (XLNet model)
- contains `xlm`: TFXLMModel (XLM model)
- contains `xlm`: TFXLMModel (XLM model)
- contains `ctrl`: TFCTRLModel (CTRL model)
This class cannot be instantiated using `__init__()` (throws an error).
This class cannot be instantiated using `__init__()` (throws an error).
"""
"""
...
@@ -73,7 +75,7 @@ class TFAutoModel(object):
...
@@ -73,7 +75,7 @@ class TFAutoModel(object):
- contains `gpt2`: TFGPT2Model (OpenAI GPT-2 model)
- contains `gpt2`: TFGPT2Model (OpenAI GPT-2 model)
- contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
- contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
- contains `xlnet`: TFXLNetModel (XLNet model)
- contains `xlnet`: TFXLNetModel (XLNet model)
- contains `
xlm
`: TF
XLM
Model (
XLM
model)
- contains `
ctrl
`: TF
CTRL
Model (
CTRL
model)
Params:
Params:
pretrained_model_name_or_path: either:
pretrained_model_name_or_path: either:
...
@@ -147,10 +149,12 @@ class TFAutoModel(object):
...
@@ -147,10 +149,12 @@ class TFAutoModel(object):
return
TFXLNetModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
TFXLNetModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlm'
in
pretrained_model_name_or_path
:
elif
'xlm'
in
pretrained_model_name_or_path
:
return
TFXLMModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
TFXLMModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'ctrl'
in
pretrained_model_name_or_path
:
return
TFCTRLModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta'
, 'ctrl'
"
.
format
(
pretrained_model_name_or_path
))
class
TFAutoModelWithLMHead
(
object
):
class
TFAutoModelWithLMHead
(
object
):
...
@@ -173,6 +177,7 @@ class TFAutoModelWithLMHead(object):
...
@@ -173,6 +177,7 @@ class TFAutoModelWithLMHead(object):
- contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
- contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
- contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
- contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
- contains `xlm`: TFXLMWithLMHeadModel (XLM model)
- contains `xlm`: TFXLMWithLMHeadModel (XLM model)
- contains `ctrl`: TFCTRLLMHeadModel (CTRL model)
This class cannot be instantiated using `__init__()` (throws an error).
This class cannot be instantiated using `__init__()` (throws an error).
"""
"""
...
@@ -198,6 +203,7 @@ class TFAutoModelWithLMHead(object):
...
@@ -198,6 +203,7 @@ class TFAutoModelWithLMHead(object):
- contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
- contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
- contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
- contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
- contains `xlm`: TFXLMWithLMHeadModel (XLM model)
- contains `xlm`: TFXLMWithLMHeadModel (XLM model)
- contains `ctrl`: TFCTRLLMHeadModel (CTRL model)
Params:
Params:
pretrained_model_name_or_path: either:
pretrained_model_name_or_path: either:
...
@@ -271,10 +277,12 @@ class TFAutoModelWithLMHead(object):
...
@@ -271,10 +277,12 @@ class TFAutoModelWithLMHead(object):
return
TFXLNetLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
TFXLNetLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlm'
in
pretrained_model_name_or_path
:
elif
'xlm'
in
pretrained_model_name_or_path
:
return
TFXLMWithLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
TFXLMWithLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'ctrl'
in
pretrained_model_name_or_path
:
return
TFCTRLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta'
, 'ctrl'
"
.
format
(
pretrained_model_name_or_path
))
class
TFAutoModelForSequenceClassification
(
object
):
class
TFAutoModelForSequenceClassification
(
object
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment