Unverified Commit c749a543 authored by maximeilluin, committed by GitHub

Added CamembertForQuestionAnswering (#2746)

* Added CamembertForQuestionAnswering

* fixed camembert tokenizer case
parent 5211d333
examples/run_squad.py:

```diff
@@ -38,6 +38,9 @@ from transformers import (
     BertConfig,
     BertForQuestionAnswering,
     BertTokenizer,
+    CamembertConfig,
+    CamembertForQuestionAnswering,
+    CamembertTokenizer,
     DistilBertConfig,
     DistilBertForQuestionAnswering,
     DistilBertTokenizer,
```
```diff
@@ -70,12 +73,16 @@ except ImportError:

 logger = logging.getLogger(__name__)

 ALL_MODELS = sum(
-    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)),
+    (
+        tuple(conf.pretrained_config_archive_map.keys())
+        for conf in (BertConfig, CamembertConfig, RobertaConfig, XLNetConfig, XLMConfig)
+    ),
     (),
 )

 MODEL_CLASSES = {
     "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
+    "camembert": (CamembertConfig, CamembertForQuestionAnswering, CamembertTokenizer),
     "roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
     "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
     "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
```
```diff
@@ -212,7 +219,7 @@ def train(args, train_dataset, model, tokenizer):
                 "end_positions": batch[4],
             }

-            if args.model_type in ["xlm", "roberta", "distilbert"]:
+            if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
                 del inputs["token_type_ids"]

             if args.model_type in ["xlnet", "xlm"]:
```
```diff
@@ -327,7 +334,7 @@ def evaluate(args, model, tokenizer, prefix=""):
                 "token_type_ids": batch[2],
             }

-            if args.model_type in ["xlm", "roberta", "distilbert"]:
+            if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
                 del inputs["token_type_ids"]

             example_indices = batch[3]
```
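The same pruning appears in both train() and evaluate(): the models in this list either have no segment-type embedding (the RoBERTa family, including CamemBERT, and DistilBERT) or handle segments differently (XLM), so the token_type_ids tensor built by the SQuAD feature converter is dropped before the forward pass. A sketch with hypothetical stand-in tensors (in the script these come from the DataLoader batch):

```python
import torch

# Hypothetical stand-ins for batch[0], batch[1], batch[2] from the DataLoader.
input_ids = torch.tensor([[5, 121, 11, 660, 16, 730, 25543, 110, 83, 6]])
attention_mask = torch.ones_like(input_ids)
token_type_ids = torch.zeros_like(input_ids)

inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "token_type_ids": token_type_ids,
}

model_type = "camembert"
if model_type in ["xlm", "roberta", "distilbert", "camembert"]:
    del inputs["token_type_ids"]  # no segment embedding to feed

# outputs = model(**inputs)  # model loaded as in the sketch further up
```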
src/transformers/__init__.py:

```diff
@@ -221,6 +221,7 @@ if is_torch_available():
         CamembertModel,
         CamembertForSequenceClassification,
         CamembertForTokenClassification,
+        CamembertForQuestionAnswering,
         CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
     )
     from .modeling_distilbert import (
```
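With this export in place the new head is importable from the package root. A quick smoke test (assuming a build of transformers that contains this commit):

```python
from transformers import CamembertForQuestionAnswering

# The class is a thin subclass of the RoBERTa QA head (see the model diff below).
print(CamembertForQuestionAnswering.__mro__[1].__name__)  # RobertaForQuestionAnswering
```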
src/transformers/data/processors/squad.py:

```diff
@@ -123,7 +123,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
     truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
     sequence_added_tokens = (
         tokenizer.max_len - tokenizer.max_len_single_sentence + 1
-        if "roberta" in str(type(tokenizer))
+        if "roberta" in str(type(tokenizer)) or "camembert" in str(type(tokenizer))
         else tokenizer.max_len - tokenizer.max_len_single_sentence
     )
     sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
```
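The `+ 1` exists because the question occupies the first segment: BERT-style pair encoding is `[CLS] A [SEP] B [SEP]`, so two special tokens travel with the first segment, while the RoBERTa/CamemBERT scheme `<s> A </s></s> B </s>` attaches three. A sketch of the arithmetic using the same tokenizer attributes as the code above (the `max_len` attributes are specific to this transformers era; checkpoint names are the stock pretrained models):

```python
from transformers import BertTokenizer, CamembertTokenizer

bert = BertTokenizer.from_pretrained("bert-base-uncased")
cam = CamembertTokenizer.from_pretrained("camembert-base")

# Specials around a single sequence: 2 for both schemes.
print(bert.max_len - bert.max_len_single_sentence)  # 2  ([CLS] ... [SEP])
print(cam.max_len - cam.max_len_single_sentence)    # 2  (<s> ... </s>)

# Specials for a pair: 3 for BERT, 4 for CamemBERT/RoBERTa; the extra
# closing "</s>" rides with the first segment, hence the "+ 1" above.
print(bert.max_len - bert.max_len_sentences_pair)   # 3
print(cam.max_len - cam.max_len_sentences_pair)     # 4
```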
src/transformers/modeling_camembert.py:

```diff
@@ -15,7 +15,6 @@
 # limitations under the License.
 """PyTorch CamemBERT model. """

-
 import logging

 from .configuration_camembert import CamembertConfig
```
```diff
@@ -23,6 +22,7 @@ from .file_utils import add_start_docstrings
 from .modeling_roberta import (
     RobertaForMaskedLM,
     RobertaForMultipleChoice,
+    RobertaForQuestionAnswering,
     RobertaForSequenceClassification,
     RobertaForTokenClassification,
     RobertaModel,
```
```diff
@@ -37,7 +37,6 @@ CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
     "umberto-wikipedia-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/Musixmatch/umberto-wikipedia-uncased-v1/pytorch_model.bin",
 }

-
 CAMEMBERT_START_DOCSTRING = r"""
     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
```
```diff
@@ -46,7 +45,8 @@ CAMEMBERT_START_DOCSTRING = r"""
     Parameters:
         config (:class:`~transformers.CamembertConfig`): Model configuration class with all the parameters of the
-            model. Initializing with a config file does not load the weights associated with the model, only the configuration.
+            model. Initializing with a config file does not load the weights associated with the model, only the
+            configuration.
             Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
 """
```
```diff
@@ -121,3 +121,18 @@ class CamembertForTokenClassification(RobertaForTokenClassification):

     config_class = CamembertConfig
     pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
+
+
+@add_start_docstrings(
+    """CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD
+    (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).""",
+    CAMEMBERT_START_DOCSTRING,
+)
+class CamembertForQuestionAnswering(RobertaForQuestionAnswering):
+    """
+    This class overrides :class:`~transformers.RobertaForQuestionAnswering`. Please check the
+    superclass for the appropriate documentation alongside usage examples.
+    """
+
+    config_class = CamembertConfig
+    pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
```
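Finally, a usage sketch for the new class. This is hedged: "camembert-base" is the stock checkpoint, so its freshly initialised qa_outputs layer produces meaningless spans until the model is fine-tuned (e.g. with run_squad.py and `--model_type camembert`):

```python
from transformers import CamembertForQuestionAnswering, CamembertTokenizer

tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
model = CamembertForQuestionAnswering.from_pretrained("camembert-base")

question = "Qui a créé CamemBERT ?"
context = "CamemBERT est un modèle de langue français créé par Inria et Facebook AI."

inputs = tokenizer.encode_plus(question, context, return_tensors="pt")
start_logits, end_logits = model(**inputs)[:2]  # models return tuples in this release

# Greedy span decoding; only meaningful after fine-tuning.
start = start_logits.argmax(dim=-1).item()
end = end_logits.argmax(dim=-1).item()
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1].tolist()))
```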