Unverified Commit c749a543 authored by maximeilluin's avatar maximeilluin Committed by GitHub
Browse files

Added CamembertForQuestionAnswering (#2746)

* Added CamembertForQuestionAnswering

* fixed camembert tokenizer case
parent 5211d333
......@@ -38,6 +38,9 @@ from transformers import (
BertConfig,
BertForQuestionAnswering,
BertTokenizer,
CamembertConfig,
CamembertForQuestionAnswering,
CamembertTokenizer,
DistilBertConfig,
DistilBertForQuestionAnswering,
DistilBertTokenizer,
......@@ -70,12 +73,16 @@ except ImportError:
logger = logging.getLogger(__name__)
ALL_MODELS = sum(
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)),
(
tuple(conf.pretrained_config_archive_map.keys())
for conf in (BertConfig, CamembertConfig, RobertaConfig, XLNetConfig, XLMConfig)
),
(),
)
MODEL_CLASSES = {
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
"camembert": (CamembertConfig, CamembertForQuestionAnswering, CamembertTokenizer),
"roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
......@@ -212,7 +219,7 @@ def train(args, train_dataset, model, tokenizer):
"end_positions": batch[4],
}
if args.model_type in ["xlm", "roberta", "distilbert"]:
if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
del inputs["token_type_ids"]
if args.model_type in ["xlnet", "xlm"]:
......@@ -327,7 +334,7 @@ def evaluate(args, model, tokenizer, prefix=""):
"token_type_ids": batch[2],
}
if args.model_type in ["xlm", "roberta", "distilbert"]:
if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
del inputs["token_type_ids"]
example_indices = batch[3]
......
......@@ -221,6 +221,7 @@ if is_torch_available():
CamembertModel,
CamembertForSequenceClassification,
CamembertForTokenClassification,
CamembertForQuestionAnswering,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_distilbert import (
......
......@@ -123,7 +123,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
sequence_added_tokens = (
tokenizer.max_len - tokenizer.max_len_single_sentence + 1
if "roberta" in str(type(tokenizer))
if "roberta" in str(type(tokenizer)) or "camembert" in str(type(tokenizer))
else tokenizer.max_len - tokenizer.max_len_single_sentence
)
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
......
......@@ -15,7 +15,6 @@
# limitations under the License.
"""PyTorch CamemBERT model. """
import logging
from .configuration_camembert import CamembertConfig
......@@ -23,6 +22,7 @@ from .file_utils import add_start_docstrings
from .modeling_roberta import (
RobertaForMaskedLM,
RobertaForMultipleChoice,
RobertaForQuestionAnswering,
RobertaForSequenceClassification,
RobertaForTokenClassification,
RobertaModel,
......@@ -37,7 +37,6 @@ CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
"umberto-wikipedia-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/Musixmatch/umberto-wikipedia-uncased-v1/pytorch_model.bin",
}
CAMEMBERT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
......@@ -46,7 +45,8 @@ CAMEMBERT_START_DOCSTRING = r"""
Parameters:
config (:class:`~transformers.CamembertConfig`): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the configuration.
model. Initializing with a config file does not load the weights associated with the model, only the
configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
......@@ -121,3 +121,18 @@ class CamembertForTokenClassification(RobertaForTokenClassification):
config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings(
"""CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD
(a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits` """,
CAMEMBERT_START_DOCSTRING,
)
class CamembertForQuestionAnswering(RobertaForQuestionAnswering):
"""
This class overrides :class:`~transformers.RobertaForQuestionAnswering`. Please check the
superclass for the appropriate documentation alongside usage examples.
"""
config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment