Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c749a543
Unverified
Commit
c749a543
authored
Feb 21, 2020
by
maximeilluin
Committed by
GitHub
Feb 21, 2020
Browse files
Added CamembertForQuestionAnswering (#2746)
* Added CamembertForQuestionAnswering * fixed camembert tokenizer case
parent
5211d333
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
7 deletions
+30
-7
examples/run_squad.py
examples/run_squad.py
+10
-3
src/transformers/__init__.py
src/transformers/__init__.py
+1
-0
src/transformers/data/processors/squad.py
src/transformers/data/processors/squad.py
+1
-1
src/transformers/modeling_camembert.py
src/transformers/modeling_camembert.py
+18
-3
No files found.
examples/run_squad.py
View file @
c749a543
...
...
@@ -38,6 +38,9 @@ from transformers import (
BertConfig
,
BertForQuestionAnswering
,
BertTokenizer
,
CamembertConfig
,
CamembertForQuestionAnswering
,
CamembertTokenizer
,
DistilBertConfig
,
DistilBertForQuestionAnswering
,
DistilBertTokenizer
,
...
...
@@ -70,12 +73,16 @@ except ImportError:
logger
=
logging
.
getLogger
(
__name__
)
ALL_MODELS
=
sum
(
(
tuple
(
conf
.
pretrained_config_archive_map
.
keys
())
for
conf
in
(
BertConfig
,
RobertaConfig
,
XLNetConfig
,
XLMConfig
)),
(
tuple
(
conf
.
pretrained_config_archive_map
.
keys
())
for
conf
in
(
BertConfig
,
CamembertConfig
,
RobertaConfig
,
XLNetConfig
,
XLMConfig
)
),
(),
)
MODEL_CLASSES
=
{
"bert"
:
(
BertConfig
,
BertForQuestionAnswering
,
BertTokenizer
),
"camembert"
:
(
CamembertConfig
,
CamembertForQuestionAnswering
,
CamembertTokenizer
),
"roberta"
:
(
RobertaConfig
,
RobertaForQuestionAnswering
,
RobertaTokenizer
),
"xlnet"
:
(
XLNetConfig
,
XLNetForQuestionAnswering
,
XLNetTokenizer
),
"xlm"
:
(
XLMConfig
,
XLMForQuestionAnswering
,
XLMTokenizer
),
...
...
@@ -212,7 +219,7 @@ def train(args, train_dataset, model, tokenizer):
"end_positions"
:
batch
[
4
],
}
if
args
.
model_type
in
[
"xlm"
,
"roberta"
,
"distilbert"
]:
if
args
.
model_type
in
[
"xlm"
,
"roberta"
,
"distilbert"
,
"camembert"
]:
del
inputs
[
"token_type_ids"
]
if
args
.
model_type
in
[
"xlnet"
,
"xlm"
]:
...
...
@@ -327,7 +334,7 @@ def evaluate(args, model, tokenizer, prefix=""):
"token_type_ids"
:
batch
[
2
],
}
if
args
.
model_type
in
[
"xlm"
,
"roberta"
,
"distilbert"
]:
if
args
.
model_type
in
[
"xlm"
,
"roberta"
,
"distilbert"
,
"camembert"
]:
del
inputs
[
"token_type_ids"
]
example_indices
=
batch
[
3
]
...
...
src/transformers/__init__.py
View file @
c749a543
...
...
@@ -221,6 +221,7 @@ if is_torch_available():
CamembertModel
,
CamembertForSequenceClassification
,
CamembertForTokenClassification
,
CamembertForQuestionAnswering
,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
)
from
.modeling_distilbert
import
(
...
...
src/transformers/data/processors/squad.py
View file @
c749a543
...
...
@@ -123,7 +123,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
truncated_query
=
tokenizer
.
encode
(
example
.
question_text
,
add_special_tokens
=
False
,
max_length
=
max_query_length
)
sequence_added_tokens
=
(
tokenizer
.
max_len
-
tokenizer
.
max_len_single_sentence
+
1
if
"roberta"
in
str
(
type
(
tokenizer
))
if
"roberta"
in
str
(
type
(
tokenizer
))
or
"camembert"
in
str
(
type
(
tokenizer
))
else
tokenizer
.
max_len
-
tokenizer
.
max_len_single_sentence
)
sequence_pair_added_tokens
=
tokenizer
.
max_len
-
tokenizer
.
max_len_sentences_pair
...
...
src/transformers/modeling_camembert.py
View file @
c749a543
...
...
@@ -15,7 +15,6 @@
# limitations under the License.
"""PyTorch CamemBERT model. """
import
logging
from
.configuration_camembert
import
CamembertConfig
...
...
@@ -23,6 +22,7 @@ from .file_utils import add_start_docstrings
from
.modeling_roberta
import
(
RobertaForMaskedLM
,
RobertaForMultipleChoice
,
RobertaForQuestionAnswering
,
RobertaForSequenceClassification
,
RobertaForTokenClassification
,
RobertaModel
,
...
...
@@ -37,7 +37,6 @@ CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
"umberto-wikipedia-uncased-v1"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/Musixmatch/umberto-wikipedia-uncased-v1/pytorch_model.bin"
,
}
CAMEMBERT_START_DOCSTRING
=
r
"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
...
...
@@ -46,7 +45,8 @@ CAMEMBERT_START_DOCSTRING = r"""
Parameters:
config (:class:`~transformers.CamembertConfig`): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the configuration.
model. Initializing with a config file does not load the weights associated with the model, only the
configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
...
...
@@ -121,3 +121,18 @@ class CamembertForTokenClassification(RobertaForTokenClassification):
config_class
=
CamembertConfig
pretrained_model_archive_map
=
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
@
add_start_docstrings
(
"""CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD
(a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits` """
,
CAMEMBERT_START_DOCSTRING
,
)
class
CamembertForQuestionAnswering
(
RobertaForQuestionAnswering
):
"""
This class overrides :class:`~transformers.RobertaForQuestionAnswering`. Please check the
superclass for the appropriate documentation alongside usage examples.
"""
config_class
=
CamembertConfig
pretrained_model_archive_map
=
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment