Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c2827379
"docs/source/vscode:/vscode.git/clone" did not exist on "e3a30e2b99e04e13b8540977775dec2567719e67"
Commit
c2827379
authored
Dec 11, 2019
by
Morgan Funtowicz
Browse files
Add missing DistilBert and Roberta to AutoModelForTokenClassification
parent
b040bff6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
10 deletions
+33
-10
transformers/modeling_auto.py
transformers/modeling_auto.py
+18
-7
transformers/modeling_tf_auto.py
transformers/modeling_tf_auto.py
+15
-3
No files found.
transformers/modeling_auto.py
View file @
c2827379
...
...
@@ -31,8 +31,10 @@ from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
from
.modeling_xlnet
import
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
,
\
XLNetForTokenClassification
from
.modeling_xlm
import
XLMModel
,
XLMWithLMHeadModel
,
XLMForSequenceClassification
,
XLMForQuestionAnswering
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
from
.modeling_distilbert
import
DistilBertModel
,
DistilBertForQuestionAnswering
,
DistilBertForMaskedLM
,
DistilBertForSequenceClassification
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
\
RobertaForTokenClassification
from
.modeling_distilbert
import
DistilBertModel
,
DistilBertForQuestionAnswering
,
DistilBertForMaskedLM
,
\
DistilBertForSequenceClassification
,
DistilBertForTokenClassification
from
.modeling_camembert
import
CamembertModel
,
CamembertForMaskedLM
,
CamembertForSequenceClassification
,
\
CamembertForMultipleChoice
,
CamembertForTokenClassification
from
.modeling_albert
import
AlbertModel
,
AlbertForMaskedLM
,
AlbertForSequenceClassification
,
AlbertForQuestionAnswering
...
...
@@ -720,8 +722,9 @@ class AutoModelForTokenClassification:
- isInstance of `distilbert` configuration class: DistilBertModel (DistilBERT model)
- isInstance of `bert` configuration class: BertModel (Bert model)
- isInstance of `xlnet` configuration class: XLNetModel (XLNet model)
- isInstance of `xlm` configuration class: XLMModel (XLM model)
- isInstance of `camembert` configuration class: CamembertModel (Camembert model)
- isInstance of `roberta` configuration class: RobertaModel (Roberta model)
Examples::
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
...
...
@@ -729,10 +732,14 @@ class AutoModelForTokenClassification:
"""
if
isinstance
(
config
,
CamembertConfig
):
return
CamembertForTokenClassification
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForTokenClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForTokenClassification
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForTokenClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
@
classmethod
...
...
@@ -746,10 +753,10 @@ class AutoModelForTokenClassification:
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForTokenClassification (DistilBERT model)
- contains `albert`: AlbertForTokenClassification (ALBERT model)
- contains `camembert`: CamembertForTokenClassification (Camembert model)
- contains `bert`: BertForTokenClassification (Bert model)
- contains `xlnet`: XLNetForTokenClassification (XLNet model)
- contains `xlm`: XLMForTokenClassification (XLM model)
- contains `roberta`: RobertaForTokenClassification (Roberta model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
...
...
@@ -809,10 +816,14 @@ class AutoModelForTokenClassification:
"""
if
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'bert'
in
pretrained_model_name_or_path
:
return
BertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
return
RobertaForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlnet'
in
pretrained_model_name_or_path
:
return
XLNetForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'camembert'"
.
format
(
pretrained_model_name_or_path
))
"'bert', 'xlnet', 'camembert'
, 'distilbert', 'roberta'
"
.
format
(
pretrained_model_name_or_path
))
transformers/modeling_tf_auto.py
View file @
c2827379
...
...
@@ -30,8 +30,8 @@ from .modeling_tf_transfo_xl import TFTransfoXLModel, TFTransfoXLLMHeadModel
from
.modeling_tf_xlnet
import
TFXLNetModel
,
TFXLNetLMHeadModel
,
TFXLNetForSequenceClassification
,
\
TFXLNetForQuestionAnsweringSimple
,
TFXLNetForTokenClassification
from
.modeling_tf_xlm
import
TFXLMModel
,
TFXLMWithLMHeadModel
,
TFXLMForSequenceClassification
,
TFXLMForQuestionAnsweringSimple
from
.modeling_tf_roberta
import
TFRobertaModel
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
from
.modeling_tf_distilbert
import
TFDistilBertModel
,
TFDistilBertForQuestionAnswering
,
TFDistilBertForMaskedLM
,
TFDistilBertForSequenceClassification
from
.modeling_tf_roberta
import
TFRobertaModel
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
,
TFRobertaForTokenClassification
from
.modeling_tf_distilbert
import
TFDistilBertModel
,
TFDistilBertForQuestionAnswering
,
TFDistilBertForMaskedLM
,
TFDistilBertForSequenceClassification
,
TFDistilBertForTokenClassification
from
.modeling_tf_ctrl
import
TFCTRLModel
,
TFCTRLLMHeadModel
from
.file_utils
import
add_start_docstrings
...
...
@@ -687,6 +687,8 @@ class TFAutoModelForTokenClassification:
The model class to instantiate is selected based on the configuration class:
- isInstance of `bert` configuration class: BertModel (Bert model)
- isInstance of `xlnet` configuration class: XLNetModel (XLNet model)
- isInstance of `distilbert` configuration class: DistilBertModel (DistilBert model)
- isInstance of `roberta` configuration class: RobertaModel (Roberta model)
Examples::
...
...
@@ -697,6 +699,10 @@ class TFAutoModelForTokenClassification:
return
TFBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForTokenClassification
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForTokenClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
@
classmethod
...
...
@@ -711,6 +717,8 @@ class TFAutoModelForTokenClassification:
in the `pretrained_model_name_or_path` string (in the following order):
- contains `bert`: BertForTokenClassification (Bert model)
- contains `xlnet`: XLNetForTokenClassification (XLNet model)
- contains `distilbert`: DistilBertForTokenClassification (DistilBert model)
- contains `roberta`: RobertaForTokenClassification (Roberta model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
...
...
@@ -772,6 +780,10 @@ class TFAutoModelForTokenClassification:
return
TFBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlnet'
in
pretrained_model_name_or_path
:
return
TFXLNetForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'distilbert'
in
pretrained_model_name_or_path
:
return
TFDistilBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
return
TFRobertaForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet'"
.
format
(
pretrained_model_name_or_path
))
"'bert', 'xlnet'
, 'distilbert', 'roberta'
"
.
format
(
pretrained_model_name_or_path
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment