Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a701c9b3
"examples/vscode:/vscode.git/clone" did not exist on "dd52804f5fce0a568ffbb3dc7fd088d2de0a0e56"
Commit
a701c9b3
authored
Oct 11, 2019
by
Lysandre
Browse files
CTRL to tf automodels
parent
d844db40
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
3 deletions
+11
-3
transformers/modeling_tf_auto.py
transformers/modeling_tf_auto.py
+11
-3
No files found.
transformers/modeling_tf_auto.py
View file @
a701c9b3
...
...
@@ -26,6 +26,7 @@ from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSeque
from
.modeling_tf_xlm
import
TFXLMModel
,
TFXLMWithLMHeadModel
,
TFXLMForSequenceClassification
,
TFXLMForQuestionAnsweringSimple
from
.modeling_tf_roberta
import
TFRobertaModel
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
from
.modeling_tf_distilbert
import
TFDistilBertModel
,
TFDistilBertForQuestionAnswering
,
TFDistilBertForMaskedLM
,
TFDistilBertForSequenceClassification
from
.modeling_tf_ctrl
import
TFCTRLModel
,
TFCTRLLMHeadModel
from
.file_utils
import
add_start_docstrings
...
...
@@ -52,6 +53,7 @@ class TFAutoModel(object):
- contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
- contains `xlnet`: TFXLNetModel (XLNet model)
- contains `xlm`: TFXLMModel (XLM model)
- contains `ctrl`: TFCTRLModel (CTRL model)
This class cannot be instantiated using `__init__()` (throws an error).
"""
...
...
@@ -73,7 +75,7 @@ class TFAutoModel(object):
- contains `gpt2`: TFGPT2Model (OpenAI GPT-2 model)
- contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
- contains `xlnet`: TFXLNetModel (XLNet model)
- contains `
xlm
`: TF
XLM
Model (
XLM
model)
- contains `
ctrl
`: TF
CTRL
Model (
CTRL
model)
Params:
pretrained_model_name_or_path: either:
...
...
@@ -147,10 +149,12 @@ class TFAutoModel(object):
return
TFXLNetModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlm'
in
pretrained_model_name_or_path
:
return
TFXLMModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'ctrl'
in
pretrained_model_name_or_path
:
return
TFCTRLModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta'
, 'ctrl'
"
.
format
(
pretrained_model_name_or_path
))
class
TFAutoModelWithLMHead
(
object
):
...
...
@@ -173,6 +177,7 @@ class TFAutoModelWithLMHead(object):
- contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
- contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
- contains `xlm`: TFXLMWithLMHeadModel (XLM model)
- contains `ctrl`: TFCTRLLMHeadModel (CTRL model)
This class cannot be instantiated using `__init__()` (throws an error).
"""
...
...
@@ -198,6 +203,7 @@ class TFAutoModelWithLMHead(object):
- contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
- contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
- contains `xlm`: TFXLMWithLMHeadModel (XLM model)
- contains `ctrl`: TFCTRLLMHeadModel (CTRL model)
Params:
pretrained_model_name_or_path: either:
...
...
@@ -271,10 +277,12 @@ class TFAutoModelWithLMHead(object):
return
TFXLNetLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlm'
in
pretrained_model_name_or_path
:
return
TFXLMWithLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'ctrl'
in
pretrained_model_name_or_path
:
return
TFCTRLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta'
, 'ctrl'
"
.
format
(
pretrained_model_name_or_path
))
class
TFAutoModelForSequenceClassification
(
object
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment