"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3994fa5bafa56db6581d962d562f3c54fac291df"
Unverified Commit d1370d29 authored by NielsRogge, committed by GitHub

Add DeBERTa head models (#9691)

* Add DebertaForMaskedLM, DebertaForTokenClassification, DebertaForQuestionAnswering

* Add docs and fix quality

* Fix Deberta not having pooler
parent a7b62fec
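For context, a minimal usage sketch of the three new head models added here (it uses the "microsoft/deberta-base" checkpoint referenced in the docstrings below; `num_labels` is illustrative, and the task-specific heads are newly initialized until fine-tuned):

import torch
from transformers import (
    DebertaForMaskedLM,
    DebertaForQuestionAnswering,
    DebertaForTokenClassification,
    DebertaTokenizer,
)

tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
mlm_model = DebertaForMaskedLM.from_pretrained("microsoft/deberta-base")
ner_model = DebertaForTokenClassification.from_pretrained("microsoft/deberta-base", num_labels=9)
qa_model = DebertaForQuestionAnswering.from_pretrained("microsoft/deberta-base")

inputs = tokenizer("DeBERTa improves BERT with disentangled attention.", return_tensors="pt")
with torch.no_grad():
    print(mlm_model(**inputs).logits.shape)       # (1, sequence_length, vocab_size)
    print(ner_model(**inputs).logits.shape)       # (1, sequence_length, 9)
    print(qa_model(**inputs).start_logits.shape)  # (1, sequence_length)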
docs/source/model_doc/deberta.rst
@@ -70,8 +70,29 @@ DebertaPreTrainedModel
:members:
DebertaForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DebertaForMaskedLM
:members:
DebertaForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DebertaForSequenceClassification
:members:
DebertaForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DebertaForTokenClassification
:members:
DebertaForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DebertaForQuestionAnswering
:members:
src/transformers/__init__.py
@@ -477,7 +477,10 @@ if is_torch_available():
"DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaForSequenceClassification",
"DebertaModel",
"DebertaForMaskedLM",
"DebertaPreTrainedModel", "DebertaPreTrainedModel",
"DebertaForTokenClassification",
"DebertaForQuestionAnswering",
]
)
_import_structure["models.distilbert"].extend(
@@ -1527,7 +1530,10 @@ if TYPE_CHECKING:
)
from .models.deberta import (
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
DebertaForMaskedLM,
DebertaForQuestionAnswering,
DebertaForSequenceClassification,
DebertaForTokenClassification,
DebertaModel,
DebertaPreTrainedModel,
)
...
src/transformers/models/auto/modeling_auto.py
@@ -62,7 +62,13 @@ from ..camembert.modeling_camembert import (
CamembertModel,
)
from ..ctrl.modeling_ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel
from ..deberta.modeling_deberta import (
DebertaForMaskedLM,
DebertaForQuestionAnswering,
DebertaForSequenceClassification,
DebertaForTokenClassification,
DebertaModel,
)
from ..distilbert.modeling_distilbert import (
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
@@ -378,6 +384,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(FunnelConfig, FunnelForMaskedLM),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
]
)
@@ -426,6 +433,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
(FunnelConfig, FunnelForMaskedLM),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
]
)
@@ -503,6 +511,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
(FunnelConfig, FunnelForQuestionAnswering),
(LxmertConfig, LxmertForQuestionAnswering),
(MPNetConfig, MPNetForQuestionAnswering),
(DebertaConfig, DebertaForQuestionAnswering),
]
)
@@ -533,6 +542,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
(FlaubertConfig, FlaubertForTokenClassification),
(FunnelConfig, FunnelForTokenClassification),
(MPNetConfig, MPNetForTokenClassification),
(DebertaConfig, DebertaForTokenClassification),
]
)
...
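With DebertaConfig registered in these mappings, the corresponding auto classes can dispatch DeBERTa checkpoints to the new heads. A sketch of the expected behaviour (checkpoint name taken from the docstrings in this commit; the heads are newly initialized):

from transformers import AutoConfig, AutoModelForMaskedLM, AutoModelForQuestionAnswering

config = AutoConfig.from_pretrained("microsoft/deberta-base")
mlm = AutoModelForMaskedLM.from_config(config)                                 # builds a DebertaForMaskedLM
qa = AutoModelForQuestionAnswering.from_pretrained("microsoft/deberta-base")  # builds a DebertaForQuestionAnswering
print(type(mlm).__name__, type(qa).__name__)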
src/transformers/models/deberta/__init__.py
@@ -31,7 +31,10 @@ if is_torch_available():
"DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaForSequenceClassification",
"DebertaModel",
"DebertaForMaskedLM",
"DebertaPreTrainedModel", "DebertaPreTrainedModel",
"DebertaForTokenClassification",
"DebertaForQuestionAnswering",
]
@@ -42,7 +45,10 @@ if TYPE_CHECKING:
if is_torch_available():
from .modeling_deberta import (
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
DebertaForMaskedLM,
DebertaForQuestionAnswering,
DebertaForSequenceClassification,
DebertaForTokenClassification,
DebertaModel,
DebertaPreTrainedModel,
)
...
src/transformers/models/deberta/modeling_deberta.py
@@ -24,7 +24,13 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_deberta import DebertaConfig
@@ -945,6 +951,135 @@ class DebertaModel(DebertaPreTrainedModel):
)
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top. """, DEBERTA_START_DOCSTRING)
class DebertaForMaskedLM(DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.deberta = DebertaModel(config)
self.cls = DebertaOnlyMLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="microsoft/deberta-base",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
(masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.deberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
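# Illustrative fill-mask sketch for DebertaForMaskedLM (checkpoint name as in the docstring
# above; for this checkpoint the MLM head may be newly initialized, so predictions are only
# meaningful after fine-tuning). -100 in `labels` marks positions excluded from the loss:
import torch
from transformers import DebertaForMaskedLM, DebertaTokenizer

tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
model = DebertaForMaskedLM.from_pretrained("microsoft/deberta-base")

text = f"The capital of France is {tokenizer.mask_token}."
inputs = tokenizer(text, return_tensors="pt")
labels = inputs["input_ids"].clone()
labels[inputs["input_ids"] != tokenizer.mask_token_id] = -100  # score only the masked token

with torch.no_grad():
    outputs = model(**inputs, labels=labels)
print(outputs.loss, outputs.logits.shape)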
# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
class DebertaPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
class DebertaLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = DebertaPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
class DebertaOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = DebertaLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
@add_start_docstrings(
"""
DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
@@ -1049,3 +1184,192 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
DEBERTA_START_DOCSTRING,
)
class DebertaForTokenClassification(DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.deberta = DebertaModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="microsoft/deberta-base",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.deberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
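# Illustrative token-classification sketch for the class above (checkpoint name as in the
# docstring; num_labels is an assumption, and the classification head is newly initialized
# until fine-tuned on labelled NER data):
import torch
from transformers import DebertaForTokenClassification, DebertaTokenizer

tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
model = DebertaForTokenClassification.from_pretrained("microsoft/deberta-base", num_labels=5)

inputs = tokenizer("Hugging Face is based in New York City", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits          # (1, sequence_length, num_labels)
predicted_label_ids = logits.argmax(dim=-1)  # one label id per token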
@add_start_docstrings(
"""
DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
DEBERTA_START_DOCSTRING,
)
class DebertaForQuestionAnswering(DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.deberta = DebertaModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="microsoft/deberta-base",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
sequence are not taken into account for computing the loss.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.deberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[1:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
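The start and end logits above map directly onto an answer span. A minimal extractive question-answering sketch (checkpoint name as in the docstring; the QA head is newly initialized until fine-tuned on a dataset such as SQuAD):

import torch
from transformers import DebertaForQuestionAnswering, DebertaTokenizer

tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
model = DebertaForQuestionAnswering.from_pretrained("microsoft/deberta-base")

question = "Who proposed DeBERTa?"
context = "DeBERTa was proposed by researchers at Microsoft."
inputs = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

start = int(outputs.start_logits.argmax())
end = int(outputs.end_logits.argmax())
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
print(answer)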
src/transformers/utils/dummy_pt_objects.py
@@ -739,6 +739,24 @@ class CTRLPreTrainedModel:
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
class DebertaForMaskedLM:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class DebertaForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class DebertaForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@@ -748,6 +766,15 @@ class DebertaForSequenceClassification:
requires_pytorch(self)
class DebertaForTokenClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class DebertaModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
...
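These dummy classes keep the new names importable from transformers when PyTorch is not installed; using them then raises an informative error instead of failing at import time. A sketch of the intended behaviour (assuming a torch-free environment):

try:
    from transformers import DebertaForTokenClassification
    DebertaForTokenClassification()  # requires_pytorch raises here, asking the user to install PyTorch
except ImportError as err:
    print(err)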