Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c749a543
Unverified
Commit
c749a543
authored
Feb 21, 2020
by
maximeilluin
Committed by
GitHub
Feb 21, 2020
Browse files
Added CamembertForQuestionAnswering (#2746)
* Added CamembertForQuestionAnswering * fixed camembert tokenizer case
parent
5211d333
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
7 deletions
+30
-7
examples/run_squad.py
examples/run_squad.py
+10
-3
src/transformers/__init__.py
src/transformers/__init__.py
+1
-0
src/transformers/data/processors/squad.py
src/transformers/data/processors/squad.py
+1
-1
src/transformers/modeling_camembert.py
src/transformers/modeling_camembert.py
+18
-3
No files found.
examples/run_squad.py
View file @
c749a543
...
...
@@ -38,6 +38,9 @@ from transformers import (
BertConfig
,
BertForQuestionAnswering
,
BertTokenizer
,
CamembertConfig
,
CamembertForQuestionAnswering
,
CamembertTokenizer
,
DistilBertConfig
,
DistilBertForQuestionAnswering
,
DistilBertTokenizer
,
...
...
@@ -70,12 +73,16 @@ except ImportError:
logger
=
logging
.
getLogger
(
__name__
)
ALL_MODELS
=
sum
(
(
tuple
(
conf
.
pretrained_config_archive_map
.
keys
())
for
conf
in
(
BertConfig
,
RobertaConfig
,
XLNetConfig
,
XLMConfig
)),
(
tuple
(
conf
.
pretrained_config_archive_map
.
keys
())
for
conf
in
(
BertConfig
,
CamembertConfig
,
RobertaConfig
,
XLNetConfig
,
XLMConfig
)
),
(),
)
MODEL_CLASSES
=
{
"bert"
:
(
BertConfig
,
BertForQuestionAnswering
,
BertTokenizer
),
"camembert"
:
(
CamembertConfig
,
CamembertForQuestionAnswering
,
CamembertTokenizer
),
"roberta"
:
(
RobertaConfig
,
RobertaForQuestionAnswering
,
RobertaTokenizer
),
"xlnet"
:
(
XLNetConfig
,
XLNetForQuestionAnswering
,
XLNetTokenizer
),
"xlm"
:
(
XLMConfig
,
XLMForQuestionAnswering
,
XLMTokenizer
),
...
...
@@ -212,7 +219,7 @@ def train(args, train_dataset, model, tokenizer):
"end_positions"
:
batch
[
4
],
}
if
args
.
model_type
in
[
"xlm"
,
"roberta"
,
"distilbert"
]:
if
args
.
model_type
in
[
"xlm"
,
"roberta"
,
"distilbert"
,
"camembert"
]:
del
inputs
[
"token_type_ids"
]
if
args
.
model_type
in
[
"xlnet"
,
"xlm"
]:
...
...
@@ -327,7 +334,7 @@ def evaluate(args, model, tokenizer, prefix=""):
"token_type_ids"
:
batch
[
2
],
}
if
args
.
model_type
in
[
"xlm"
,
"roberta"
,
"distilbert"
]:
if
args
.
model_type
in
[
"xlm"
,
"roberta"
,
"distilbert"
,
"camembert"
]:
del
inputs
[
"token_type_ids"
]
example_indices
=
batch
[
3
]
...
...
src/transformers/__init__.py
View file @
c749a543
...
...
@@ -221,6 +221,7 @@ if is_torch_available():
CamembertModel
,
CamembertForSequenceClassification
,
CamembertForTokenClassification
,
CamembertForQuestionAnswering
,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
)
from
.modeling_distilbert
import
(
...
...
src/transformers/data/processors/squad.py
View file @
c749a543
...
...
@@ -123,7 +123,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
truncated_query
=
tokenizer
.
encode
(
example
.
question_text
,
add_special_tokens
=
False
,
max_length
=
max_query_length
)
sequence_added_tokens
=
(
tokenizer
.
max_len
-
tokenizer
.
max_len_single_sentence
+
1
if
"roberta"
in
str
(
type
(
tokenizer
))
if
"roberta"
in
str
(
type
(
tokenizer
))
or
"camembert"
in
str
(
type
(
tokenizer
))
else
tokenizer
.
max_len
-
tokenizer
.
max_len_single_sentence
)
sequence_pair_added_tokens
=
tokenizer
.
max_len
-
tokenizer
.
max_len_sentences_pair
...
...
src/transformers/modeling_camembert.py
View file @
c749a543
...
...
@@ -15,7 +15,6 @@
# limitations under the License.
"""PyTorch CamemBERT model. """
import
logging
from
.configuration_camembert
import
CamembertConfig
...
...
@@ -23,6 +22,7 @@ from .file_utils import add_start_docstrings
from
.modeling_roberta
import
(
RobertaForMaskedLM
,
RobertaForMultipleChoice
,
RobertaForQuestionAnswering
,
RobertaForSequenceClassification
,
RobertaForTokenClassification
,
RobertaModel
,
...
...
@@ -37,7 +37,6 @@ CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
"umberto-wikipedia-uncased-v1"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/Musixmatch/umberto-wikipedia-uncased-v1/pytorch_model.bin"
,
}
CAMEMBERT_START_DOCSTRING
=
r
"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
...
...
@@ -46,7 +45,8 @@ CAMEMBERT_START_DOCSTRING = r"""
Parameters:
config (:class:`~transformers.CamembertConfig`): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the configuration.
model. Initializing with a config file does not load the weights associated with the model, only the
configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
...
...
@@ -121,3 +121,18 @@ class CamembertForTokenClassification(RobertaForTokenClassification):
config_class
=
CamembertConfig
pretrained_model_archive_map
=
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
@
add_start_docstrings
(
"""CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD
(a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits` """
,
CAMEMBERT_START_DOCSTRING
,
)
class
CamembertForQuestionAnswering
(
RobertaForQuestionAnswering
):
"""
This class overrides :class:`~transformers.RobertaForQuestionAnswering`. Please check the
superclass for the appropriate documentation alongside usage examples.
"""
config_class
=
CamembertConfig
pretrained_model_archive_map
=
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment