chenpangpang / transformers / Commits / 0cdfcca2

Unverified commit 0cdfcca2, authored Nov 21, 2019 by Thomas Wolf, committed by GitHub on Nov 21, 2019

Merge pull request #1860 from stefan-it/camembert-for-token-classification

[WIP] Add support for CamembertForTokenClassification

Parents: e70cdf08, 56c84863

4 changed files with 61 additions and 2 deletions (+61 -2)
examples/run_ner.py                      +3  -1
transformers/__init__.py                 +1  -0
transformers/modeling_camembert.py       +37 -1
transformers/tokenization_camembert.py   +20 -0
examples/run_ner.py  (view file @ 0cdfcca2)

@@ -37,6 +37,7 @@ from transformers import AdamW, get_linear_schedule_with_warmup
 from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
 from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
 from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer
+from transformers import CamembertConfig, CamembertForTokenClassification, CamembertTokenizer

 logger = logging.getLogger(__name__)

@@ -47,7 +48,8 @@ ALL_MODELS = sum(
 MODEL_CLASSES = {
     "bert": (BertConfig, BertForTokenClassification, BertTokenizer),
     "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
+    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
+    "camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
 }
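With the entry above in place, selecting CamemBERT in run_ner.py is just a dictionary lookup on the model type argument. A minimal sketch of that dispatch, assuming the transformers 2.x API of this commit and a hypothetical three-tag label list (the actual script reads its labels and arguments from the command line):

    from transformers import CamembertConfig, CamembertForTokenClassification, CamembertTokenizer

    MODEL_CLASSES = {
        "camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
    }

    # run_ner.py resolves the chosen model type through this mapping, then builds
    # config, tokenizer, and model from the same pretrained checkpoint.
    config_class, model_class, tokenizer_class = MODEL_CLASSES["camembert"]

    labels = ["O", "B-PER", "I-PER"]  # hypothetical label list, for illustration only
    config = config_class.from_pretrained("camembert-base", num_labels=len(labels))
    tokenizer = tokenizer_class.from_pretrained("camembert-base")
    model = model_class.from_pretrained("camembert-base", config=config)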
transformers/__init__.py  (view file @ 0cdfcca2)

@@ -100,6 +100,7 @@ if is_torch_available():
                                          DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_camembert import (CamembertForMaskedLM, CamembertModel,
                                      CamembertForSequenceClassification, CamembertForMultipleChoice,
+                                     CamembertForTokenClassification,
                                      CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
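This re-export is what makes the new head reachable from the package root (when torch is available), so downstream code does not need to import from the internal module:

    # After this commit, the class is importable directly from the package root
    # rather than from transformers.modeling_camembert.
    from transformers import CamembertForTokenClassification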
transformers/modeling_camembert.py  (view file @ 0cdfcca2)

@@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function,
 import logging

-from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice
+from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification
 from .configuration_camembert import CamembertConfig
 from .file_utils import add_start_docstrings
@@ -255,3 +255,39 @@ class CamembertForMultipleChoice(RobertaForMultipleChoice):
     """
     config_class = CamembertConfig
     pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
+
+
+@add_start_docstrings("""CamemBERT Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+    CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
+class CamembertForTokenClassification(RobertaForTokenClassification):
+    r"""
+    **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+        Labels for computing the token classification loss.
+        Indices should be in ``[0, ..., config.num_labels - 1]``.
+
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Classification loss.
+        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
+            Classification scores (before SoftMax).
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
+        model = CamembertForTokenClassification.from_pretrained('camembert-base')
+        input_ids = torch.tensor(tokenizer.encode("J'aime le camembert !", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, scores = outputs[:2]
+    """
+    config_class = CamembertConfig
+    pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
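Because the new class inherits everything from RobertaForTokenClassification, fine-tuning is the usual forward-with-labels pattern. A minimal training-step sketch under the transformers 2.x API of this commit; the label values and num_labels=5 are placeholders, not values from the commit:

    import torch
    from transformers import CamembertForTokenClassification, CamembertTokenizer

    tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
    # num_labels sizes the classification layer; 5 is an arbitrary example value.
    model = CamembertForTokenClassification.from_pretrained('camembert-base', num_labels=5)

    input_ids = torch.tensor([tokenizer.encode("J'aime le camembert !", add_special_tokens=True)])
    labels = torch.zeros_like(input_ids)  # placeholder: every token tagged with label 0
    loss, scores = model(input_ids, labels=labels)[:2]
    loss.backward()  # gradients are now ready for an optimizer step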
transformers/tokenization_camembert.py  (view file @ 0cdfcca2)

@@ -16,9 +16,14 @@
 from __future__ import (absolute_import, division, print_function,
                         unicode_literals)

 import logging
+import os
+from shutil import copyfile

 import sentencepiece as spm

 from transformers.tokenization_utils import PreTrainedTokenizer

 logger = logging.getLogger(__name__)

 VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'}
@@ -55,6 +60,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
         self.max_len_sentences_pair = self.max_len - 4  # take into account special tokens
         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(str(vocab_file))
+        self.vocab_file = vocab_file
         # HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
         # sentencepiece vocabulary (this is the case for <s> and </s>)
         self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3}
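The reserved-id table above coexists with the sentencepiece vocabulary by shifting every sentencepiece id past the reserved slots. A sketch of that layout, assuming the offset equals the number of reserved tokens (four here); it mirrors the convert_id_to_token logic visible in the next hunk rather than quoting the class's actual methods:

    fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3}
    fairseq_ids_to_tokens = {v: k for k, v in fairseq_tokens_to_ids.items()}
    fairseq_offset = len(fairseq_tokens_to_ids)  # sentencepiece piece i gets id i + 4

    def id_to_token(index, sp_model):
        # Reserved ids resolve through the fairseq table; everything else is a
        # sentencepiece piece, shifted back by the offset.
        if index in fairseq_ids_to_tokens:
            return fairseq_ids_to_tokens[index]
        return sp_model.IdToPiece(index - fairseq_offset)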
@@ -135,3 +141,17 @@ class CamembertTokenizer(PreTrainedTokenizer):
         if index in self.fairseq_ids_to_tokens:
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+    def save_vocabulary(self, save_directory):
+        """ Save the sentencepiece vocabulary (copy original file) and special tokens file
+            to a directory.
+        """
+        if not os.path.isdir(save_directory):
+            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+            return
+        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
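The new save_vocabulary() is what save_pretrained() invokes under the hood in the 2.x API, so the tokenizer can now round-trip through a local directory. A small sketch; the directory name is arbitrary:

    import os
    from transformers import CamembertTokenizer

    tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
    os.makedirs('./camembert-tok', exist_ok=True)       # save_vocabulary expects an existing directory
    tokenizer.save_pretrained('./camembert-tok')        # copies sentencepiece.bpe.model into it
    reloaded = CamembertTokenizer.from_pretrained('./camembert-tok')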