Commit c2827379 authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Add missing DistilBert and Roberta to AutoModelForTokenClassification

parent b040bff6
...@@ -31,8 +31,10 @@ from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel ...@@ -31,8 +31,10 @@ from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering, \ from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering, \
XLNetForTokenClassification XLNetForTokenClassification
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, \
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification RobertaForTokenClassification
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, \
DistilBertForSequenceClassification, DistilBertForTokenClassification
from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, \ from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, \
CamembertForMultipleChoice, CamembertForTokenClassification CamembertForMultipleChoice, CamembertForTokenClassification
from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForQuestionAnswering from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForQuestionAnswering
...@@ -720,7 +722,8 @@ class AutoModelForTokenClassification: ...@@ -720,7 +722,8 @@ class AutoModelForTokenClassification:
- isInstance of `distilbert` configuration class: DistilBertModel (DistilBERT model) - isInstance of `distilbert` configuration class: DistilBertModel (DistilBERT model)
- isInstance of `bert` configuration class: BertModel (Bert model) - isInstance of `bert` configuration class: BertModel (Bert model)
- isInstance of `xlnet` configuration class: XLNetModel (XLNet model) - isInstance of `xlnet` configuration class: XLNetModel (XLNet model)
- isInstance of `xlm` configuration class: XLMModel (XLM model) - isInstance of `camembert` configuration class: CamembertModel (Camembert model)
- isInstance of `roberta` configuration class: RobertaModel (Roberta model)
Examples:: Examples::
...@@ -729,10 +732,14 @@ class AutoModelForTokenClassification: ...@@ -729,10 +732,14 @@ class AutoModelForTokenClassification:
""" """
if isinstance(config, CamembertConfig): if isinstance(config, CamembertConfig):
return CamembertForTokenClassification(config) return CamembertForTokenClassification(config)
elif isinstance(config, DistilBertConfig):
return DistilBertForTokenClassification(config)
elif isinstance(config, BertConfig): elif isinstance(config, BertConfig):
return BertForTokenClassification(config) return BertForTokenClassification(config)
elif isinstance(config, XLNetConfig): elif isinstance(config, XLNetConfig):
return XLNetForTokenClassification(config) return XLNetForTokenClassification(config)
elif isinstance(config, RobertaConfig):
return RobertaForTokenClassification(config)
raise ValueError("Unrecognized configuration class {}".format(config)) raise ValueError("Unrecognized configuration class {}".format(config))
@classmethod @classmethod
...@@ -746,10 +753,10 @@ class AutoModelForTokenClassification: ...@@ -746,10 +753,10 @@ class AutoModelForTokenClassification:
The model class to instantiate is selected as the first pattern matching The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order): in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertForTokenClassification (DistilBERT model) - contains `distilbert`: DistilBertForTokenClassification (DistilBERT model)
- contains `albert`: AlbertForTokenClassification (ALBERT model) - contains `camembert`: CamembertForTokenClassification (Camembert model)
- contains `bert`: BertForTokenClassification (Bert model) - contains `bert`: BertForTokenClassification (Bert model)
- contains `xlnet`: XLNetForTokenClassification (XLNet model) - contains `xlnet`: XLNetForTokenClassification (XLNet model)
- contains `xlm`: XLMForTokenClassification (XLM model) - contains `roberta`: RobertaForTokenClassification (Roberta model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()` To train the model, you should first set it back in training mode with `model.train()`
...@@ -809,10 +816,14 @@ class AutoModelForTokenClassification: ...@@ -809,10 +816,14 @@ class AutoModelForTokenClassification:
""" """
if 'camembert' in pretrained_model_name_or_path: if 'camembert' in pretrained_model_name_or_path:
return CamembertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return CamembertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return DistilBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path: elif 'bert' in pretrained_model_name_or_path:
return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path: elif 'xlnet' in pretrained_model_name_or_path:
return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of " raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet', 'camembert'".format(pretrained_model_name_or_path)) "'bert', 'xlnet', 'camembert', 'distilbert', 'roberta'".format(pretrained_model_name_or_path))
...@@ -30,8 +30,8 @@ from .modeling_tf_transfo_xl import TFTransfoXLModel, TFTransfoXLLMHeadModel ...@@ -30,8 +30,8 @@ from .modeling_tf_transfo_xl import TFTransfoXLModel, TFTransfoXLLMHeadModel
from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSequenceClassification, \ from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSequenceClassification, \
TFXLNetForQuestionAnsweringSimple, TFXLNetForTokenClassification TFXLNetForQuestionAnsweringSimple, TFXLNetForTokenClassification
from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple
from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, TFRobertaForTokenClassification
from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification, TFDistilBertForTokenClassification
from .modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel from .modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
...@@ -687,6 +687,8 @@ class TFAutoModelForTokenClassification: ...@@ -687,6 +687,8 @@ class TFAutoModelForTokenClassification:
The model class to instantiate is selected based on the configuration class: The model class to instantiate is selected based on the configuration class:
- isInstance of `bert` configuration class: BertModel (Bert model) - isInstance of `bert` configuration class: BertModel (Bert model)
- isInstance of `xlnet` configuration class: XLNetModel (XLNet model) - isInstance of `xlnet` configuration class: XLNetModel (XLNet model)
- isInstance of `distilbert` configuration class: DistilBertModel (DistilBert model)
- isInstance of `roberta` configuration class: RobteraModel (Roberta model)
Examples:: Examples::
...@@ -697,6 +699,10 @@ class TFAutoModelForTokenClassification: ...@@ -697,6 +699,10 @@ class TFAutoModelForTokenClassification:
return TFBertForTokenClassification(config) return TFBertForTokenClassification(config)
elif isinstance(config, XLNetConfig): elif isinstance(config, XLNetConfig):
return TFXLNetForTokenClassification(config) return TFXLNetForTokenClassification(config)
elif isinstance(config, DistilBertConfig):
return TFDistilBertForTokenClassification(config)
elif isinstance(config, RobertaConfig):
return TFRobertaForTokenClassification(config)
raise ValueError("Unrecognized configuration class {}".format(config)) raise ValueError("Unrecognized configuration class {}".format(config))
@classmethod @classmethod
...@@ -711,6 +717,8 @@ class TFAutoModelForTokenClassification: ...@@ -711,6 +717,8 @@ class TFAutoModelForTokenClassification:
in the `pretrained_model_name_or_path` string (in the following order): in the `pretrained_model_name_or_path` string (in the following order):
- contains `bert`: BertForTokenClassification (Bert model) - contains `bert`: BertForTokenClassification (Bert model)
- contains `xlnet`: XLNetForTokenClassification (XLNet model) - contains `xlnet`: XLNetForTokenClassification (XLNet model)
- contains `distilbert`: DistilBertForTokenClassification (DistilBert model)
- contains `roberta`: RobertaForTokenClassification (Roberta model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()` To train the model, you should first set it back in training mode with `model.train()`
...@@ -772,6 +780,10 @@ class TFAutoModelForTokenClassification: ...@@ -772,6 +780,10 @@ class TFAutoModelForTokenClassification:
return TFBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return TFBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path: elif 'xlnet' in pretrained_model_name_or_path:
return TFXLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return TFXLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return TFDistilBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return TFRobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of " raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'xlnet'".format(pretrained_model_name_or_path)) "'bert', 'xlnet', 'distilbert', 'roberta'".format(pretrained_model_name_or_path))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment