Unverified Commit 0cdfcca2 authored by Thomas Wolf, committed by GitHub

Merge pull request #1860 from stefan-it/camembert-for-token-classification

[WIP] Add support for CamembertForTokenClassification
parents e70cdf08 56c84863
@@ -37,6 +37,7 @@ from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer
from transformers import CamembertConfig, CamembertForTokenClassification, CamembertTokenizer
logger = logging.getLogger(__name__)
@@ -47,7 +48,8 @@ ALL_MODELS = sum(
MODEL_CLASSES = {
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
"camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
}
...
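With the "camembert" entry registered, the NER example script can resolve CamemBERT the same way as the other architectures. A minimal sketch of the lookup, assuming a 9-label CoNLL-style tag set (the label count is illustrative, not part of the diff):

    config_class, model_class, tokenizer_class = MODEL_CLASSES["camembert"]
    config = config_class.from_pretrained("camembert-base", num_labels=9)  # 9 labels is an assumed NER tag set
    tokenizer = tokenizer_class.from_pretrained("camembert-base")
    model = model_class.from_pretrained("camembert-base", config=config)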
@@ -100,6 +100,7 @@ if is_torch_available():
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_camembert import (CamembertForMaskedLM, CamembertModel,
CamembertForSequenceClassification, CamembertForMultipleChoice,
CamembertForTokenClassification,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
...
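With this export in place (guarded by is_torch_available(), so it requires a PyTorch install), the new head becomes importable from the top-level package:

    from transformers import CamembertConfig, CamembertForTokenClassification, CamembertTokenizer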
@@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function,
import logging
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification
from .configuration_camembert import CamembertConfig
from .file_utils import add_start_docstrings
@@ -255,3 +255,39 @@ class CamembertForMultipleChoice(RobertaForMultipleChoice):
"""
config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
@add_start_docstrings("""CamemBERT Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
class CamembertForTokenClassification(RobertaForTokenClassification):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
Classification scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
model = CamembertForTokenClassification.from_pretrained('camembert-base')
input_ids = torch.tensor(tokenizer.encode("J'aime le camembert !", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
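Continuing the docstring example above: the scores tensor holds per-token logits, which reduce to label ids with an argmax. A minimal decoding sketch (the tag list below is a hypothetical label set, not something the checkpoint defines):

    import torch

    predictions = torch.argmax(scores, dim=2)  # shape: (batch_size, sequence_length)
    label_list = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]  # assumed labels, for illustration only
    tags = [label_list[i] for i in predictions[0].tolist()]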
@@ -16,9 +16,14 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import logging
import os
from shutil import copyfile
import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'}
@@ -55,6 +60,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(str(vocab_file))
self.vocab_file = vocab_file
# HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
# sentencepiece vocabulary (this is the case for <s> and </s>)
self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3}
@@ -135,3 +141,17 @@ class CamembertTokenizer(PreTrainedTokenizer):
if index in self.fairseq_ids_to_tokens:
return self.fairseq_ids_to_tokens[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)
def save_vocabulary(self, save_directory):
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
to a directory.
"""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
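save_vocabulary is what makes a fine-tuned tokenizer reloadable from disk; PreTrainedTokenizer.save_pretrained calls it internally. A minimal round-trip sketch (the directory name is arbitrary; it is created first because save_vocabulary logs an error when the path is not an existing directory):

    import os

    os.makedirs('./camembert-ner', exist_ok=True)
    tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
    tokenizer.save_pretrained('./camembert-ner')  # copies sentencepiece.bpe.model into the directory
    reloaded = CamembertTokenizer.from_pretrained('./camembert-ner')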