"docs/source/vscode:/vscode.git/clone" did not exist on "e3a30e2b99e04e13b8540977775dec2567719e67"
Commit c2827379 authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Add missing DistilBert and Roberta to AutoModelForTokenClassification

parent b040bff6
......@@ -31,8 +31,10 @@ from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering, \
XLNetForTokenClassification
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, \
RobertaForTokenClassification
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, \
DistilBertForSequenceClassification, DistilBertForTokenClassification
from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, \
CamembertForMultipleChoice, CamembertForTokenClassification
from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForQuestionAnswering
......@@ -720,8 +722,9 @@ class AutoModelForTokenClassification:
- isInstance of `distilbert` configuration class: DistilBertModel (DistilBERT model)
- isInstance of `bert` configuration class: BertModel (Bert model)
- isInstance of `xlnet` configuration class: XLNetModel (XLNet model)
- isInstance of `xlm` configuration class: XLMModel (XLM model)
- isInstance of `camembert` configuration class: CamembertModel (Camembert model)
- isInstance of `roberta` configuration class: RobertaModel (Roberta model)
Examples::
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
......@@ -729,10 +732,14 @@ class AutoModelForTokenClassification:
"""
if isinstance(config, CamembertConfig):
return CamembertForTokenClassification(config)
elif isinstance(config, DistilBertConfig):
return DistilBertForTokenClassification(config)
elif isinstance(config, BertConfig):
return BertForTokenClassification(config)
elif isinstance(config, XLNetConfig):
return XLNetForTokenClassification(config)
elif isinstance(config, RobertaConfig):
return RobertaForTokenClassification(config)
raise ValueError("Unrecognized configuration class {}".format(config))
@classmethod
......@@ -746,10 +753,10 @@ class AutoModelForTokenClassification:
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForTokenClassification (DistilBERT model)
- contains `albert`: AlbertForTokenClassification (ALBERT model)
- contains `camembert`: CamembertForTokenClassification (Camembert model)
- contains `bert`: BertForTokenClassification (Bert model)
- contains `xlnet`: XLNetForTokenClassification (XLNet model)
- contains `xlm`: XLMForTokenClassification (XLM model)
- contains `roberta`: RobertaForTokenClassification (Roberta model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
......@@ -809,10 +816,14 @@ class AutoModelForTokenClassification:
"""
if 'camembert' in pretrained_model_name_or_path:
return CamembertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return DistilBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'camembert'".format(pretrained_model_name_or_path))
"'bert', 'xlnet', 'camembert', 'distilbert', 'roberta'".format(pretrained_model_name_or_path))
......@@ -30,8 +30,8 @@ from .modeling_tf_transfo_xl import TFTransfoXLModel, TFTransfoXLLMHeadModel
from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSequenceClassification, \
TFXLNetForQuestionAnsweringSimple, TFXLNetForTokenClassification
from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple
from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification
from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification
from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, TFRobertaForTokenClassification
from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification, TFDistilBertForTokenClassification
from .modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel
from .file_utils import add_start_docstrings
......@@ -687,6 +687,8 @@ class TFAutoModelForTokenClassification:
The model class to instantiate is selected based on the configuration class:
- isInstance of `bert` configuration class: BertModel (Bert model)
- isInstance of `xlnet` configuration class: XLNetModel (XLNet model)
- isInstance of `distilbert` configuration class: DistilBertModel (DistilBert model)
- isInstance of `roberta` configuration class: RobteraModel (Roberta model)
Examples::
......@@ -697,6 +699,10 @@ class TFAutoModelForTokenClassification:
return TFBertForTokenClassification(config)
elif isinstance(config, XLNetConfig):
return TFXLNetForTokenClassification(config)
elif isinstance(config, DistilBertConfig):
return TFDistilBertForTokenClassification(config)
elif isinstance(config, RobertaConfig):
return TFRobertaForTokenClassification(config)
raise ValueError("Unrecognized configuration class {}".format(config))
@classmethod
......@@ -711,6 +717,8 @@ class TFAutoModelForTokenClassification:
in the `pretrained_model_name_or_path` string (in the following order):
- contains `bert`: BertForTokenClassification (Bert model)
- contains `xlnet`: XLNetForTokenClassification (XLNet model)
- contains `distilbert`: DistilBertForTokenClassification (DistilBert model)
- contains `roberta`: RobertaForTokenClassification (Roberta model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
......@@ -772,6 +780,10 @@ class TFAutoModelForTokenClassification:
return TFBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
return TFXLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return TFDistilBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return TFRobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet'".format(pretrained_model_name_or_path))
"'bert', 'xlnet', 'distilbert', 'roberta'".format(pretrained_model_name_or_path))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment