Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c2827379
Commit
c2827379
authored
Dec 11, 2019
by
Morgan Funtowicz
Browse files
Add missing DistilBert and Roberta to AutoModelForTokenClassification
parent
b040bff6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
10 deletions
+33
-10
transformers/modeling_auto.py
transformers/modeling_auto.py
+18
-7
transformers/modeling_tf_auto.py
transformers/modeling_tf_auto.py
+15
-3
No files found.
transformers/modeling_auto.py
View file @
c2827379
...
...
@@ -31,8 +31,10 @@ from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
from
.modeling_xlnet
import
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
,
\
XLNetForTokenClassification
from
.modeling_xlm
import
XLMModel
,
XLMWithLMHeadModel
,
XLMForSequenceClassification
,
XLMForQuestionAnswering
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
from
.modeling_distilbert
import
DistilBertModel
,
DistilBertForQuestionAnswering
,
DistilBertForMaskedLM
,
DistilBertForSequenceClassification
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
\
RobertaForTokenClassification
from
.modeling_distilbert
import
DistilBertModel
,
DistilBertForQuestionAnswering
,
DistilBertForMaskedLM
,
\
DistilBertForSequenceClassification
,
DistilBertForTokenClassification
from
.modeling_camembert
import
CamembertModel
,
CamembertForMaskedLM
,
CamembertForSequenceClassification
,
\
CamembertForMultipleChoice
,
CamembertForTokenClassification
from
.modeling_albert
import
AlbertModel
,
AlbertForMaskedLM
,
AlbertForSequenceClassification
,
AlbertForQuestionAnswering
...
...
@@ -720,8 +722,9 @@ class AutoModelForTokenClassification:
- isInstance of `distilbert` configuration class: DistilBertModel (DistilBERT model)
- isInstance of `bert` configuration class: BertModel (Bert model)
- isInstance of `xlnet` configuration class: XLNetModel (XLNet model)
- isInstance of `xlm` configuration class: XLMModel (XLM model)
- isInstance of `camembert` configuration class: CamembertModel (Camembert model)
- isInstance of `roberta` configuration class: RobertaModel (Roberta model)
Examples::
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
...
...
@@ -729,10 +732,14 @@ class AutoModelForTokenClassification:
"""
if
isinstance
(
config
,
CamembertConfig
):
return
CamembertForTokenClassification
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
DistilBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
BertConfig
):
return
BertForTokenClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
XLNetForTokenClassification
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
RobertaForTokenClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
@
classmethod
...
...
@@ -746,10 +753,10 @@ class AutoModelForTokenClassification:
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForTokenClassification (DistilBERT model)
- contains `
al
bert`:
Al
bertForTokenClassification (
ALBERT
model)
- contains `
camem
bert`:
Camem
bertForTokenClassification (
Camembert
model)
- contains `bert`: BertForTokenClassification (Bert model)
- contains `xlnet`: XLNetForTokenClassification (XLNet model)
- contains `
xlm`: XLM
ForTokenClassification (
XLM
model)
- contains `
roberta`: Roberta
ForTokenClassification (
Roberta
model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
...
...
@@ -809,10 +816,14 @@ class AutoModelForTokenClassification:
"""
if
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'bert'
in
pretrained_model_name_or_path
:
return
BertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
return
RobertaForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlnet'
in
pretrained_model_name_or_path
:
return
XLNetForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'camembert'"
.
format
(
pretrained_model_name_or_path
))
"'bert', 'xlnet', 'camembert'
, 'distilbert', 'roberta'
"
.
format
(
pretrained_model_name_or_path
))
transformers/modeling_tf_auto.py
View file @
c2827379
...
...
@@ -30,8 +30,8 @@ from .modeling_tf_transfo_xl import TFTransfoXLModel, TFTransfoXLLMHeadModel
from
.modeling_tf_xlnet
import
TFXLNetModel
,
TFXLNetLMHeadModel
,
TFXLNetForSequenceClassification
,
\
TFXLNetForQuestionAnsweringSimple
,
TFXLNetForTokenClassification
from
.modeling_tf_xlm
import
TFXLMModel
,
TFXLMWithLMHeadModel
,
TFXLMForSequenceClassification
,
TFXLMForQuestionAnsweringSimple
from
.modeling_tf_roberta
import
TFRobertaModel
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
from
.modeling_tf_distilbert
import
TFDistilBertModel
,
TFDistilBertForQuestionAnswering
,
TFDistilBertForMaskedLM
,
TFDistilBertForSequenceClassification
from
.modeling_tf_roberta
import
TFRobertaModel
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
,
TFRobertaForTokenClassification
from
.modeling_tf_distilbert
import
TFDistilBertModel
,
TFDistilBertForQuestionAnswering
,
TFDistilBertForMaskedLM
,
TFDistilBertForSequenceClassification
,
TFDistilBertForTokenClassification
from
.modeling_tf_ctrl
import
TFCTRLModel
,
TFCTRLLMHeadModel
from
.file_utils
import
add_start_docstrings
...
...
@@ -687,6 +687,8 @@ class TFAutoModelForTokenClassification:
The model class to instantiate is selected based on the configuration class:
- isInstance of `bert` configuration class: BertModel (Bert model)
- isInstance of `xlnet` configuration class: XLNetModel (XLNet model)
- isInstance of `distilbert` configuration class: DistilBertModel (DistilBert model)
- isInstance of `roberta` configuration class: RobteraModel (Roberta model)
Examples::
...
...
@@ -697,6 +699,10 @@ class TFAutoModelForTokenClassification:
return
TFBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
XLNetConfig
):
return
TFXLNetForTokenClassification
(
config
)
elif
isinstance
(
config
,
DistilBertConfig
):
return
TFDistilBertForTokenClassification
(
config
)
elif
isinstance
(
config
,
RobertaConfig
):
return
TFRobertaForTokenClassification
(
config
)
raise
ValueError
(
"Unrecognized configuration class {}"
.
format
(
config
))
@
classmethod
...
...
@@ -711,6 +717,8 @@ class TFAutoModelForTokenClassification:
in the `pretrained_model_name_or_path` string (in the following order):
- contains `bert`: BertForTokenClassification (Bert model)
- contains `xlnet`: XLNetForTokenClassification (XLNet model)
- contains `distilbert`: DistilBertForTokenClassification (DistilBert model)
- contains `roberta`: RobertaForTokenClassification (Roberta model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
...
...
@@ -772,6 +780,10 @@ class TFAutoModelForTokenClassification:
return
TFBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlnet'
in
pretrained_model_name_or_path
:
return
TFXLNetForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'distilbert'
in
pretrained_model_name_or_path
:
return
TFDistilBertForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
return
TFRobertaForTokenClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet'"
.
format
(
pretrained_model_name_or_path
))
"'bert', 'xlnet'
, 'distilbert', 'roberta'
"
.
format
(
pretrained_model_name_or_path
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment