Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ecf15ebf
Commit
ecf15ebf
authored
Nov 29, 2019
by
Elad Segal
Committed by
Lysandre Debut
Nov 29, 2019
Browse files
Add ALBERT to AutoClasses
parent
4a666885
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
17 deletions
+46
-17
transformers/configuration_auto.py
transformers/configuration_auto.py
+10
-5
transformers/modeling_auto.py
transformers/modeling_auto.py
+26
-7
transformers/tokenization_auto.py
transformers/tokenization_auto.py
+10
-5
No files found.
transformers/configuration_auto.py
View file @
ecf15ebf
...
@@ -28,6 +28,7 @@ from .configuration_roberta import RobertaConfig
...
@@ -28,6 +28,7 @@ from .configuration_roberta import RobertaConfig
from
.configuration_distilbert
import
DistilBertConfig
from
.configuration_distilbert
import
DistilBertConfig
from
.configuration_ctrl
import
CTRLConfig
from
.configuration_ctrl
import
CTRLConfig
from
.configuration_camembert
import
CamembertConfig
from
.configuration_camembert
import
CamembertConfig
from
.configuration_albert
import
AlbertConfig
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -44,14 +45,15 @@ class AutoConfig(object):
...
@@ -44,14 +45,15 @@ class AutoConfig(object):
The base model class to instantiate is selected as the first pattern matching
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `albert`: AlbertConfig (ALBERT model)
- contains `camembert`: CamembertConfig (CamemBERT model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `bert`: BertConfig (Bert model)
- contains `bert`: BertConfig (Bert model)
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlm`: XLMConfig (XLM model)
- contains `xlm`: XLMConfig (XLM model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `camembert`: CamembertConfig (CamemBERT model)
- contains `ctrl` : CTRLConfig (CTRL model)
- contains `ctrl` : CTRLConfig (CTRL model)
This class cannot be instantiated using `__init__()` (throw an error).
This class cannot be instantiated using `__init__()` (throw an error).
"""
"""
...
@@ -67,14 +69,15 @@ class AutoConfig(object):
...
@@ -67,14 +69,15 @@ class AutoConfig(object):
The configuration class to instantiate is selected as the first pattern matching
The configuration class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `albert`: AlbertConfig (ALBERT model)
- contains `camembert`: CamembertConfig (CamemBERT model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `bert`: BertConfig (Bert model)
- contains `bert`: BertConfig (Bert model)
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlm`: XLMConfig (XLM model)
- contains `xlm`: XLMConfig (XLM model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `camembert`: CamembertConfig (CamemBERT model)
- contains `ctrl` : CTRLConfig (CTRL model)
- contains `ctrl` : CTRLConfig (CTRL model)
Params:
Params:
pretrained_model_name_or_path: either:
pretrained_model_name_or_path: either:
...
@@ -122,6 +125,8 @@ class AutoConfig(object):
...
@@ -122,6 +125,8 @@ class AutoConfig(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
return
DistilBertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
return
CamembertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
...
@@ -142,4 +147,4 @@ class AutoConfig(object):
...
@@ -142,4 +147,4 @@ class AutoConfig(object):
return
CTRLConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
return
CTRLConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'camembert', 'ctrl'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta',
'distilbert',
'camembert', 'ctrl'
, 'albert'
"
.
format
(
pretrained_model_name_or_path
))
transformers/modeling_auto.py
View file @
ecf15ebf
...
@@ -28,6 +28,8 @@ from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassifica
...
@@ -28,6 +28,8 @@ from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassifica
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
from
.modeling_distilbert
import
DistilBertModel
,
DistilBertForQuestionAnswering
,
DistilBertForMaskedLM
,
DistilBertForSequenceClassification
from
.modeling_distilbert
import
DistilBertModel
,
DistilBertForQuestionAnswering
,
DistilBertForMaskedLM
,
DistilBertForSequenceClassification
from
.modeling_camembert
import
CamembertModel
,
CamembertForMaskedLM
,
CamembertForSequenceClassification
,
CamembertForMultipleChoice
from
.modeling_camembert
import
CamembertModel
,
CamembertForMaskedLM
,
CamembertForSequenceClassification
,
CamembertForMultipleChoice
from
.modeling_camembert
import
CamembertModel
,
CamembertForMaskedLM
,
CamembertForSequenceClassification
,
CamembertForMultipleChoice
from
.modeling_albert
import
AlbertModel
,
AlbertForMaskedLM
,
AlbertForSequenceClassification
,
AlbertForQuestionAnswering
from
.modeling_utils
import
PreTrainedModel
,
SequenceSummary
from
.modeling_utils
import
PreTrainedModel
,
SequenceSummary
...
@@ -49,15 +51,16 @@ class AutoModel(object):
...
@@ -49,15 +51,16 @@ class AutoModel(object):
The base model class to instantiate is selected as the first pattern matching
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `albert`: AlbertModel (ALBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
- contains `xlm`: XLMModel (XLM model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
This class cannot be instantiated using `__init__()` (throws an error).
This class cannot be instantiated using `__init__()` (throws an error).
"""
"""
...
@@ -73,15 +76,16 @@ class AutoModel(object):
...
@@ -73,15 +76,16 @@ class AutoModel(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `albert`: AlbertModel (ALBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
- contains `xlm`: XLMModel (XLM model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
To train the model, you should first set it back in training mode with `model.train()`
...
@@ -144,6 +148,8 @@ class AutoModel(object):
...
@@ -144,6 +148,8 @@ class AutoModel(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
DistilBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
CamembertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
...
@@ -164,7 +170,7 @@ class AutoModel(object):
...
@@ -164,7 +170,7 @@ class AutoModel(object):
return
CTRLModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
CTRLModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta, 'ctrl'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta, 'ctrl'
, 'distilbert', 'camembert', 'albert'
"
.
format
(
pretrained_model_name_or_path
))
class
AutoModelWithLMHead
(
object
):
class
AutoModelWithLMHead
(
object
):
...
@@ -180,15 +186,16 @@ class AutoModelWithLMHead(object):
...
@@ -180,15 +186,16 @@ class AutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `albert`: AlbertForMaskedLM (ALBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
- contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model)
- contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model)
- contains `ctrl`: CTRLLMModel (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
- contains `xlm`: XLMWithLMHeadModel (XLM model)
- contains `xlm`: XLMWithLMHeadModel (XLM model)
- contains `ctrl`: CTRLLMHeadModel (Salesforce CTRL model)
This class cannot be instantiated using `__init__()` (throws an error).
This class cannot be instantiated using `__init__()` (throws an error).
"""
"""
...
@@ -207,6 +214,7 @@ class AutoModelWithLMHead(object):
...
@@ -207,6 +214,7 @@ class AutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `albert`: AlbertForMaskedLM (ALBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `bert`: BertForMaskedLM (Bert model)
...
@@ -215,6 +223,7 @@ class AutoModelWithLMHead(object):
...
@@ -215,6 +223,7 @@ class AutoModelWithLMHead(object):
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
- contains `xlm`: XLMWithLMHeadModel (XLM model)
- contains `xlm`: XLMWithLMHeadModel (XLM model)
- contains `ctrl`: CTRLLMHeadModel (Salesforce CTRL model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
To train the model, you should first set it back in training mode with `model.train()`
...
@@ -276,6 +285,8 @@ class AutoModelWithLMHead(object):
...
@@ -276,6 +285,8 @@ class AutoModelWithLMHead(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
DistilBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
CamembertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
...
@@ -296,7 +307,7 @@ class AutoModelWithLMHead(object):
...
@@ -296,7 +307,7 @@ class AutoModelWithLMHead(object):
return
CTRLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
CTRLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta','ctrl'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta','ctrl'
, 'distilbert', 'camembert', 'albert'
"
.
format
(
pretrained_model_name_or_path
))
class
AutoModelForSequenceClassification
(
object
):
class
AutoModelForSequenceClassification
(
object
):
...
@@ -312,6 +323,7 @@ class AutoModelForSequenceClassification(object):
...
@@ -312,6 +323,7 @@ class AutoModelForSequenceClassification(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `albert`: AlbertForSequenceClassification (ALBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `bert`: BertForSequenceClassification (Bert model)
- contains `bert`: BertForSequenceClassification (Bert model)
...
@@ -335,6 +347,7 @@ class AutoModelForSequenceClassification(object):
...
@@ -335,6 +347,7 @@ class AutoModelForSequenceClassification(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `albert`: AlbertForSequenceClassification (ALBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `bert`: BertForSequenceClassification (Bert model)
- contains `bert`: BertForSequenceClassification (Bert model)
...
@@ -402,6 +415,8 @@ class AutoModelForSequenceClassification(object):
...
@@ -402,6 +415,8 @@ class AutoModelForSequenceClassification(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
DistilBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
CamembertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
...
@@ -414,7 +429,7 @@ class AutoModelForSequenceClassification(object):
...
@@ -414,7 +429,7 @@ class AutoModelForSequenceClassification(object):
return
XLMForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
XLMForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm', 'roberta'"
.
format
(
pretrained_model_name_or_path
))
"'bert', 'xlnet', 'xlm', 'roberta'
, 'distilbert', 'camembert', 'albert'
"
.
format
(
pretrained_model_name_or_path
))
class
AutoModelForQuestionAnswering
(
object
):
class
AutoModelForQuestionAnswering
(
object
):
...
@@ -430,6 +445,7 @@ class AutoModelForQuestionAnswering(object):
...
@@ -430,6 +445,7 @@ class AutoModelForQuestionAnswering(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
- contains `albert`: AlbertForQuestionAnswering (ALBERT model)
- contains `bert`: BertForQuestionAnswering (Bert model)
- contains `bert`: BertForQuestionAnswering (Bert model)
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
- contains `xlm`: XLMForQuestionAnswering (XLM model)
- contains `xlm`: XLMForQuestionAnswering (XLM model)
...
@@ -451,6 +467,7 @@ class AutoModelForQuestionAnswering(object):
...
@@ -451,6 +467,7 @@ class AutoModelForQuestionAnswering(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
- contains `albert`: AlbertForQuestionAnswering (ALBERT model)
- contains `bert`: BertForQuestionAnswering (Bert model)
- contains `bert`: BertForQuestionAnswering (Bert model)
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
- contains `xlm`: XLMForQuestionAnswering (XLM model)
- contains `xlm`: XLMForQuestionAnswering (XLM model)
...
@@ -513,6 +530,8 @@ class AutoModelForQuestionAnswering(object):
...
@@ -513,6 +530,8 @@ class AutoModelForQuestionAnswering(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
DistilBertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'bert'
in
pretrained_model_name_or_path
:
elif
'bert'
in
pretrained_model_name_or_path
:
return
BertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
BertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlnet'
in
pretrained_model_name_or_path
:
elif
'xlnet'
in
pretrained_model_name_or_path
:
...
@@ -521,4 +540,4 @@ class AutoModelForQuestionAnswering(object):
...
@@ -521,4 +540,4 @@ class AutoModelForQuestionAnswering(object):
return
XLMForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
XLMForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm'"
.
format
(
pretrained_model_name_or_path
))
"'bert', 'xlnet', 'xlm'
, 'distilbert', 'albert'
"
.
format
(
pretrained_model_name_or_path
))
transformers/tokenization_auto.py
View file @
ecf15ebf
...
@@ -28,6 +28,7 @@ from .tokenization_xlm import XLMTokenizer
...
@@ -28,6 +28,7 @@ from .tokenization_xlm import XLMTokenizer
from
.tokenization_roberta
import
RobertaTokenizer
from
.tokenization_roberta
import
RobertaTokenizer
from
.tokenization_distilbert
import
DistilBertTokenizer
from
.tokenization_distilbert
import
DistilBertTokenizer
from
.tokenization_camembert
import
CamembertTokenizer
from
.tokenization_camembert
import
CamembertTokenizer
from
.tokenization_albert
import
AlbertTokenizer
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -42,16 +43,17 @@ class AutoTokenizer(object):
...
@@ -42,16 +43,17 @@ class AutoTokenizer(object):
The tokenizer class to instantiate is selected as the first pattern matching
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `albert`: AlbertTokenizer (ALBERT model)
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert`: BertTokenizer (Bert model)
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
This class cannot be instantiated using `__init__()` (throw an error).
This class cannot be instantiated using `__init__()` (throw an error).
"""
"""
...
@@ -66,16 +68,17 @@ class AutoTokenizer(object):
...
@@ -66,16 +68,17 @@ class AutoTokenizer(object):
The tokenizer class to instantiate is selected as the first pattern matching
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `albert`: AlbertTokenizer (ALBERT model)
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert`: BertTokenizer (Bert model)
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
Params:
Params:
pretrained_model_name_or_path: either:
pretrained_model_name_or_path: either:
...
@@ -109,6 +112,8 @@ class AutoTokenizer(object):
...
@@ -109,6 +112,8 @@ class AutoTokenizer(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
return
DistilBertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
return
CamembertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
...
@@ -129,4 +134,4 @@ class AutoTokenizer(object):
...
@@ -129,4 +134,4 @@ class AutoTokenizer(object):
return
CTRLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
return
CTRLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'camembert', 'ctrl'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta',
'distilbert,'
'camembert', 'ctrl'
, 'albert'
"
.
format
(
pretrained_model_name_or_path
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment