Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ecf15ebf
"docs/source/vscode:/vscode.git/clone" did not exist on "1d646badbb118cf126a1250b22b246572e07ac4c"
Commit
ecf15ebf
authored
Nov 29, 2019
by
Elad Segal
Committed by
Lysandre Debut
Nov 29, 2019
Browse files
Add ALBERT to AutoClasses
parent
4a666885
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
17 deletions
+46
-17
transformers/configuration_auto.py
transformers/configuration_auto.py
+10
-5
transformers/modeling_auto.py
transformers/modeling_auto.py
+26
-7
transformers/tokenization_auto.py
transformers/tokenization_auto.py
+10
-5
No files found.
transformers/configuration_auto.py
View file @
ecf15ebf
...
@@ -28,6 +28,7 @@ from .configuration_roberta import RobertaConfig
...
@@ -28,6 +28,7 @@ from .configuration_roberta import RobertaConfig
from
.configuration_distilbert
import
DistilBertConfig
from
.configuration_distilbert
import
DistilBertConfig
from
.configuration_ctrl
import
CTRLConfig
from
.configuration_ctrl
import
CTRLConfig
from
.configuration_camembert
import
CamembertConfig
from
.configuration_camembert
import
CamembertConfig
from
.configuration_albert
import
AlbertConfig
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -44,14 +45,15 @@ class AutoConfig(object):
...
@@ -44,14 +45,15 @@ class AutoConfig(object):
The base model class to instantiate is selected as the first pattern matching
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `albert`: AlbertConfig (ALBERT model)
- contains `camembert`: CamembertConfig (CamemBERT model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `bert`: BertConfig (Bert model)
- contains `bert`: BertConfig (Bert model)
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlm`: XLMConfig (XLM model)
- contains `xlm`: XLMConfig (XLM model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `camembert`: CamembertConfig (CamemBERT model)
- contains `ctrl`: CTRLConfig (CTRL model)
- contains `ctrl`: CTRLConfig (CTRL model)
This class cannot be instantiated using `__init__()` (throws an error).
This class cannot be instantiated using `__init__()` (throws an error).
"""
"""
...
@@ -67,14 +69,15 @@ class AutoConfig(object):
...
@@ -67,14 +69,15 @@ class AutoConfig(object):
The configuration class to instantiate is selected as the first pattern matching
The configuration class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `albert`: AlbertConfig (ALBERT model)
- contains `camembert`: CamembertConfig (CamemBERT model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `bert`: BertConfig (Bert model)
- contains `bert`: BertConfig (Bert model)
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlm`: XLMConfig (XLM model)
- contains `xlm`: XLMConfig (XLM model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `camembert`: CamembertConfig (CamemBERT model)
- contains `ctrl`: CTRLConfig (CTRL model)
- contains `ctrl`: CTRLConfig (CTRL model)
Params:
Params:
pretrained_model_name_or_path: either:
pretrained_model_name_or_path: either:
...
@@ -122,6 +125,8 @@ class AutoConfig(object):
...
@@ -122,6 +125,8 @@ class AutoConfig(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
return
DistilBertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
return
CamembertConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
...
@@ -142,4 +147,4 @@ class AutoConfig(object):
...
@@ -142,4 +147,4 @@ class AutoConfig(object):
return
CTRLConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
return
CTRLConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'camembert', 'ctrl'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta',
'distilbert',
'camembert', 'ctrl'
, 'albert'
"
.
format
(
pretrained_model_name_or_path
))
transformers/modeling_auto.py
View file @
ecf15ebf
...
@@ -28,6 +28,8 @@ from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassifica
...
@@ -28,6 +28,8 @@ from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassifica
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
from
.modeling_roberta
import
RobertaModel
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
from
.modeling_distilbert
import
DistilBertModel
,
DistilBertForQuestionAnswering
,
DistilBertForMaskedLM
,
DistilBertForSequenceClassification
from
.modeling_distilbert
import
DistilBertModel
,
DistilBertForQuestionAnswering
,
DistilBertForMaskedLM
,
DistilBertForSequenceClassification
from
.modeling_camembert
import
CamembertModel
,
CamembertForMaskedLM
,
CamembertForSequenceClassification
,
CamembertForMultipleChoice
from
.modeling_camembert
import
CamembertModel
,
CamembertForMaskedLM
,
CamembertForSequenceClassification
,
CamembertForMultipleChoice
from
.modeling_camembert
import
CamembertModel
,
CamembertForMaskedLM
,
CamembertForSequenceClassification
,
CamembertForMultipleChoice
from
.modeling_albert
import
AlbertModel
,
AlbertForMaskedLM
,
AlbertForSequenceClassification
,
AlbertForQuestionAnswering
from
.modeling_utils
import
PreTrainedModel
,
SequenceSummary
from
.modeling_utils
import
PreTrainedModel
,
SequenceSummary
...
@@ -49,15 +51,16 @@ class AutoModel(object):
...
@@ -49,15 +51,16 @@ class AutoModel(object):
The base model class to instantiate is selected as the first pattern matching
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `albert`: AlbertModel (ALBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
- contains `xlm`: XLMModel (XLM model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
This class cannot be instantiated using `__init__()` (throws an error).
This class cannot be instantiated using `__init__()` (throws an error).
"""
"""
...
@@ -73,15 +76,16 @@ class AutoModel(object):
...
@@ -73,15 +76,16 @@ class AutoModel(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `albert`: AlbertModel (ALBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `camembert`: CamembertModel (CamemBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
- contains `xlm`: XLMModel (XLM model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
To train the model, you should first set it back in training mode with `model.train()`
...
@@ -144,6 +148,8 @@ class AutoModel(object):
...
@@ -144,6 +148,8 @@ class AutoModel(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
DistilBertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
CamembertModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
...
@@ -164,7 +170,7 @@ class AutoModel(object):
...
@@ -164,7 +170,7 @@ class AutoModel(object):
return
CTRLModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
CTRLModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta, 'ctrl'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta, 'ctrl'
, 'distilbert', 'camembert', 'albert'
"
.
format
(
pretrained_model_name_or_path
))
class
AutoModelWithLMHead
(
object
):
class
AutoModelWithLMHead
(
object
):
...
@@ -180,15 +186,16 @@ class AutoModelWithLMHead(object):
...
@@ -180,15 +186,16 @@ class AutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `albert`: AlbertForMaskedLM (ALBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
- contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model)
- contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model)
- contains `ctrl`: CTRLLMModel (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
- contains `xlm`: XLMWithLMHeadModel (XLM model)
- contains `xlm`: XLMWithLMHeadModel (XLM model)
- contains `ctrl`: CTRLLMHeadModel (Salesforce CTRL model)
This class cannot be instantiated using `__init__()` (throws an error).
This class cannot be instantiated using `__init__()` (throws an error).
"""
"""
...
@@ -207,6 +214,7 @@ class AutoModelWithLMHead(object):
...
@@ -207,6 +214,7 @@ class AutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `albert`: AlbertForMaskedLM (ALBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model)
- contains `bert`: BertForMaskedLM (Bert model)
...
@@ -215,6 +223,7 @@ class AutoModelWithLMHead(object):
...
@@ -215,6 +223,7 @@ class AutoModelWithLMHead(object):
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
- contains `xlm`: XLMWithLMHeadModel (XLM model)
- contains `xlm`: XLMWithLMHeadModel (XLM model)
- contains `ctrl`: CTRLLMHeadModel (Salesforce CTRL model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
To train the model, you should first set it back in training mode with `model.train()`
...
@@ -276,6 +285,8 @@ class AutoModelWithLMHead(object):
...
@@ -276,6 +285,8 @@ class AutoModelWithLMHead(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
DistilBertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
CamembertForMaskedLM
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
...
@@ -296,7 +307,7 @@ class AutoModelWithLMHead(object):
...
@@ -296,7 +307,7 @@ class AutoModelWithLMHead(object):
return
CTRLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
CTRLLMHeadModel
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta','ctrl'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta','ctrl'
, 'distilbert', 'camembert', 'albert'
"
.
format
(
pretrained_model_name_or_path
))
class
AutoModelForSequenceClassification
(
object
):
class
AutoModelForSequenceClassification
(
object
):
...
@@ -312,6 +323,7 @@ class AutoModelForSequenceClassification(object):
...
@@ -312,6 +323,7 @@ class AutoModelForSequenceClassification(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `albert`: AlbertForSequenceClassification (ALBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `bert`: BertForSequenceClassification (Bert model)
- contains `bert`: BertForSequenceClassification (Bert model)
...
@@ -335,6 +347,7 @@ class AutoModelForSequenceClassification(object):
...
@@ -335,6 +347,7 @@ class AutoModelForSequenceClassification(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
- contains `albert`: AlbertForSequenceClassification (ALBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
- contains `bert`: BertForSequenceClassification (Bert model)
- contains `bert`: BertForSequenceClassification (Bert model)
...
@@ -402,6 +415,8 @@ class AutoModelForSequenceClassification(object):
...
@@ -402,6 +415,8 @@ class AutoModelForSequenceClassification(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
DistilBertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
CamembertForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
...
@@ -414,7 +429,7 @@ class AutoModelForSequenceClassification(object):
...
@@ -414,7 +429,7 @@ class AutoModelForSequenceClassification(object):
return
XLMForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
XLMForSequenceClassification
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm', 'roberta'"
.
format
(
pretrained_model_name_or_path
))
"'bert', 'xlnet', 'xlm', 'roberta'
, 'distilbert', 'camembert', 'albert'
"
.
format
(
pretrained_model_name_or_path
))
class
AutoModelForQuestionAnswering
(
object
):
class
AutoModelForQuestionAnswering
(
object
):
...
@@ -430,6 +445,7 @@ class AutoModelForQuestionAnswering(object):
...
@@ -430,6 +445,7 @@ class AutoModelForQuestionAnswering(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
- contains `albert`: AlbertForQuestionAnswering (ALBERT model)
- contains `bert`: BertForQuestionAnswering (Bert model)
- contains `bert`: BertForQuestionAnswering (Bert model)
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
- contains `xlm`: XLMForQuestionAnswering (XLM model)
- contains `xlm`: XLMForQuestionAnswering (XLM model)
...
@@ -451,6 +467,7 @@ class AutoModelForQuestionAnswering(object):
...
@@ -451,6 +467,7 @@ class AutoModelForQuestionAnswering(object):
The model class to instantiate is selected as the first pattern matching
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
- contains `albert`: AlbertForQuestionAnswering (ALBERT model)
- contains `bert`: BertForQuestionAnswering (Bert model)
- contains `bert`: BertForQuestionAnswering (Bert model)
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
- contains `xlm`: XLMForQuestionAnswering (XLM model)
- contains `xlm`: XLMForQuestionAnswering (XLM model)
...
@@ -513,6 +530,8 @@ class AutoModelForQuestionAnswering(object):
...
@@ -513,6 +530,8 @@ class AutoModelForQuestionAnswering(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
DistilBertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'bert'
in
pretrained_model_name_or_path
:
elif
'bert'
in
pretrained_model_name_or_path
:
return
BertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
BertForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
elif
'xlnet'
in
pretrained_model_name_or_path
:
elif
'xlnet'
in
pretrained_model_name_or_path
:
...
@@ -521,4 +540,4 @@ class AutoModelForQuestionAnswering(object):
...
@@ -521,4 +540,4 @@ class AutoModelForQuestionAnswering(object):
return
XLMForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
return
XLMForQuestionAnswering
.
from_pretrained
(
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'xlm'"
.
format
(
pretrained_model_name_or_path
))
"'bert', 'xlnet', 'xlm'
, 'distilbert', 'albert'
"
.
format
(
pretrained_model_name_or_path
))
transformers/tokenization_auto.py
View file @
ecf15ebf
...
@@ -28,6 +28,7 @@ from .tokenization_xlm import XLMTokenizer
...
@@ -28,6 +28,7 @@ from .tokenization_xlm import XLMTokenizer
from
.tokenization_roberta
import
RobertaTokenizer
from
.tokenization_roberta
import
RobertaTokenizer
from
.tokenization_distilbert
import
DistilBertTokenizer
from
.tokenization_distilbert
import
DistilBertTokenizer
from
.tokenization_camembert
import
CamembertTokenizer
from
.tokenization_camembert
import
CamembertTokenizer
from
.tokenization_albert
import
AlbertTokenizer
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -42,16 +43,17 @@ class AutoTokenizer(object):
...
@@ -42,16 +43,17 @@ class AutoTokenizer(object):
The tokenizer class to instantiate is selected as the first pattern matching
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `distilbert`: DistilBertTokenizer (DistilBERT model)
- contains `distilbert`: DistilBertTokenizer (DistilBERT model)
- contains `albert`: AlbertTokenizer (ALBERT model)
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert`: BertTokenizer (Bert model)
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
This class cannot be instantiated using `__init__()` (throws an error).
This class cannot be instantiated using `__init__()` (throws an error).
"""
"""
...
@@ -66,16 +68,17 @@ class AutoTokenizer(object):
...
@@ -66,16 +68,17 @@ class AutoTokenizer(object):
The tokenizer class to instantiate is selected as the first pattern matching
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `distilbert`: DistilBertTokenizer (DistilBERT model)
- contains `distilbert`: DistilBertTokenizer (DistilBERT model)
- contains `albert`: AlbertTokenizer (ALBERT model)
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert`: BertTokenizer (Bert model)
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
Params:
Params:
pretrained_model_name_or_path: either:
pretrained_model_name_or_path: either:
...
@@ -109,6 +112,8 @@ class AutoTokenizer(object):
...
@@ -109,6 +112,8 @@ class AutoTokenizer(object):
"""
"""
if
'distilbert'
in
pretrained_model_name_or_path
:
if
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
return
DistilBertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
return
CamembertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
...
@@ -129,4 +134,4 @@ class AutoTokenizer(object):
...
@@ -129,4 +134,4 @@ class AutoTokenizer(object):
return
CTRLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
return
CTRLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'camembert', 'ctrl'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta',
'distilbert,'
'camembert', 'ctrl'
, 'albert'
"
.
format
(
pretrained_model_name_or_path
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment