Unverified Commit d1370d29 authored by NielsRogge, committed by GitHub

Add DeBERTa head models (#9691)

* Add DebertaForMaskedLM, DebertaForTokenClassification, DebertaForQuestionAnswering

* Add docs and fix quality

* Fix Deberta not having pooler
parent a7b62fec
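Reviewer note: a minimal sketch of the three heads this PR adds, using a tiny illustrative config (the values below are not from the diff; real usage would load "microsoft/deberta-base" with from_pretrained, and the heads start randomly initialized unless a fine-tuned checkpoint is used):

import torch
from transformers import DebertaConfig, DebertaForMaskedLM, DebertaForQuestionAnswering, DebertaForTokenClassification

# Tiny config purely for illustration of the output shapes implemented in this PR.
config = DebertaConfig(vocab_size=128, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64)
input_ids = torch.randint(0, config.vocab_size, (1, 10))

mlm = DebertaForMaskedLM(config)
ner = DebertaForTokenClassification(config)   # config.num_labels defaults to 2
qa = DebertaForQuestionAnswering(config)

print(mlm(input_ids).logits.shape)       # torch.Size([1, 10, 128])  -- (batch, seq_len, vocab_size)
print(ner(input_ids).logits.shape)       # torch.Size([1, 10, 2])    -- (batch, seq_len, num_labels)
print(qa(input_ids).start_logits.shape)  # torch.Size([1, 10])       -- one start/end logit per token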
@@ -70,8 +70,29 @@ DebertaPreTrainedModel
    :members:


DebertaForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DebertaForMaskedLM
    :members:


DebertaForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DebertaForSequenceClassification
    :members:


DebertaForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DebertaForTokenClassification
    :members:


DebertaForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DebertaForQuestionAnswering
    :members:
@@ -477,7 +477,10 @@ if is_torch_available():
            "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
            "DebertaForSequenceClassification",
            "DebertaModel",
            "DebertaForMaskedLM",
            "DebertaPreTrainedModel",
            "DebertaForTokenClassification",
            "DebertaForQuestionAnswering",
        ]
    )
    _import_structure["models.distilbert"].extend(
@@ -1527,7 +1530,10 @@ if TYPE_CHECKING:
        )
        from .models.deberta import (
            DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
            DebertaForMaskedLM,
            DebertaForQuestionAnswering,
            DebertaForSequenceClassification,
            DebertaForTokenClassification,
            DebertaModel,
            DebertaPreTrainedModel,
        )
...
@@ -62,7 +62,13 @@ from ..camembert.modeling_camembert import (
    CamembertModel,
)
from ..ctrl.modeling_ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel
from ..deberta.modeling_deberta import (
    DebertaForMaskedLM,
    DebertaForQuestionAnswering,
    DebertaForSequenceClassification,
    DebertaForTokenClassification,
    DebertaModel,
)
from ..distilbert.modeling_distilbert import (
    DistilBertForMaskedLM,
    DistilBertForMultipleChoice,

@@ -378,6 +384,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
        (FunnelConfig, FunnelForMaskedLM),
        (MPNetConfig, MPNetForMaskedLM),
        (TapasConfig, TapasForMaskedLM),
        (DebertaConfig, DebertaForMaskedLM),
    ]
)

@@ -426,6 +433,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
        (FunnelConfig, FunnelForMaskedLM),
        (MPNetConfig, MPNetForMaskedLM),
        (TapasConfig, TapasForMaskedLM),
        (DebertaConfig, DebertaForMaskedLM),
    ]
)

@@ -503,6 +511,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
        (FunnelConfig, FunnelForQuestionAnswering),
        (LxmertConfig, LxmertForQuestionAnswering),
        (MPNetConfig, MPNetForQuestionAnswering),
        (DebertaConfig, DebertaForQuestionAnswering),
    ]
)

@@ -533,6 +542,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
        (FlaubertConfig, FlaubertForTokenClassification),
        (FunnelConfig, FunnelForTokenClassification),
        (MPNetConfig, MPNetForTokenClassification),
        (DebertaConfig, DebertaForTokenClassification),
    ]
)
...
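Reviewer note: with the mapping entries above, the auto classes now resolve a DebertaConfig to the new heads. A quick sketch (the tiny config values are illustrative only, not part of the diff):

from transformers import AutoModelForMaskedLM, AutoModelForQuestionAnswering, AutoModelForTokenClassification, DebertaConfig

config = DebertaConfig(vocab_size=128, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64)

# Each auto class dispatches on the config type via the OrderedDicts edited above.
print(type(AutoModelForMaskedLM.from_config(config)).__name__)             # DebertaForMaskedLM
print(type(AutoModelForTokenClassification.from_config(config)).__name__)  # DebertaForTokenClassification
print(type(AutoModelForQuestionAnswering.from_config(config)).__name__)    # DebertaForQuestionAnswering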
@@ -31,7 +31,10 @@ if is_torch_available():
        "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "DebertaForSequenceClassification",
        "DebertaModel",
        "DebertaForMaskedLM",
        "DebertaPreTrainedModel",
        "DebertaForTokenClassification",
        "DebertaForQuestionAnswering",
    ]

@@ -42,7 +45,10 @@ if TYPE_CHECKING:
    if is_torch_available():
        from .modeling_deberta import (
            DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
            DebertaForMaskedLM,
            DebertaForQuestionAnswering,
            DebertaForSequenceClassification,
            DebertaForTokenClassification,
            DebertaModel,
            DebertaPreTrainedModel,
        )
...
@@ -24,7 +24,13 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_deberta import DebertaConfig

@@ -945,6 +951,135 @@ class DebertaModel(DebertaPreTrainedModel):
        )
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top. """, DEBERTA_START_DOCSTRING)
class DebertaForMaskedLM(DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.deberta = DebertaModel(config)
self.cls = DebertaOnlyMLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="microsoft/deberta-base",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.deberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
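Reviewer note: a sketch of how the `labels` tensor above is typically built for this loss (values are hypothetical; in real MLM training the supervised input positions would also be replaced by the tokenizer's mask token id):

import torch
from transformers import DebertaConfig, DebertaForMaskedLM

config = DebertaConfig(vocab_size=128, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64)
model = DebertaForMaskedLM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 8))
labels = torch.full_like(input_ids, -100)  # -100 everywhere -> ignored by CrossEntropyLoss
labels[:, 3] = input_ids[:, 3]             # supervise a single "masked" position per sequence

outputs = model(input_ids, labels=labels)
print(outputs.loss)           # scalar loss computed only over the supervised positions
print(outputs.logits.shape)   # torch.Size([2, 8, 128])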
# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
class DebertaPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
class DebertaLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = DebertaPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
class DebertaOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = DebertaLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores
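Reviewer note: because `get_output_embeddings`/`set_output_embeddings` expose the decoder above, the base `PreTrainedModel` machinery ties it to the input embeddings (with the default `tie_word_embeddings=True`) and keeps it in sync when the vocabulary is resized. A sketch under those assumptions:

from transformers import DebertaConfig, DebertaForMaskedLM

config = DebertaConfig(vocab_size=128, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64)
model = DebertaForMaskedLM(config)

# Output projection shares its weight matrix with the input embeddings.
assert model.get_output_embeddings().weight is model.get_input_embeddings().weight

model.resize_token_embeddings(130)  # e.g. after adding two tokens to the tokenizer
print(model.get_input_embeddings().weight.shape)   # torch.Size([130, 32])
print(model.get_output_embeddings().weight.shape)  # torch.Size([130, 32]), still tied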
@add_start_docstrings(
    """
    DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the

@@ -1049,3 +1184,192 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
class DebertaForTokenClassification(DebertaPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.deberta = DebertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="microsoft/deberta-base",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
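Reviewer note: the `attention_mask` branch above swaps the labels at padded positions for the loss function's `ignore_index`, so padding never contributes to the loss. The same logic in isolation (illustrative values):

import torch
from torch.nn import CrossEntropyLoss

num_labels = 3
logits = torch.randn(1, 4, num_labels)         # (batch_size, seq_length, num_labels)
labels = torch.tensor([[0, 2, 1, 1]])
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding

loss_fct = CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_labels = torch.where(active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index))
loss = loss_fct(logits.view(-1, num_labels), active_labels)
print(loss)  # the padded position is ignored via ignore_index (-100)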
@add_start_docstrings(
    """
    DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DEBERTA_START_DOCSTRING,
)
class DebertaForQuestionAnswering(DebertaPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.deberta = DebertaModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="microsoft/deberta-base",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, splitting adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
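Reviewer note: a sketch of how the start/end logits above are typically consumed at inference time. The weights here are random (tiny illustrative config); a real pipeline would load a checkpoint fine-tuned on SQuAD and also enforce start <= end and a maximum answer length:

import torch
from transformers import DebertaConfig, DebertaForQuestionAnswering

config = DebertaConfig(vocab_size=128, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64)
model = DebertaForQuestionAnswering(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 12))  # question + context, already concatenated
with torch.no_grad():
    out = model(input_ids)
start = out.start_logits.argmax(-1).item()
end = out.end_logits.argmax(-1).item()
answer_ids = input_ids[0, start : end + 1]  # naive span decoding
print(start, end, answer_ids)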
@@ -739,6 +739,24 @@ class CTRLPreTrainedModel:
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None


class DebertaForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


class DebertaForQuestionAnswering:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


class DebertaForSequenceClassification:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

@@ -748,6 +766,15 @@ class DebertaForSequenceClassification:
        requires_pytorch(self)


class DebertaForTokenClassification:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


class DebertaModel:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)
...
# coding=utf-8
# Copyright 2018 Microsoft Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import random
import unittest

import numpy as np

from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor


if is_torch_available():
    import torch

    from transformers import (
        DebertaConfig,
        DebertaForMaskedLM,
        DebertaForQuestionAnswering,
        DebertaForSequenceClassification,
        DebertaForTokenClassification,
        DebertaModel,
    )
    from transformers.models.deberta.modeling_deberta import DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST


@require_torch
class DebertaModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (
        (
            DebertaModel,
            DebertaForMaskedLM,
            DebertaForSequenceClassification,
            DebertaForTokenClassification,
            DebertaForQuestionAnswering,
        )
        if is_torch_available()
        else ()
    )

    test_torchscript = False
    test_pruning = False
    test_head_masking = False
    is_encoder_decoder = False

    class DebertaModelTester(object):
        def __init__(
            self,
            parent,
            batch_size=13,
            seq_length=7,
            is_training=True,
            use_input_mask=True,
            use_token_type_ids=True,
            use_labels=True,
            vocab_size=99,
            hidden_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            intermediate_size=37,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=16,
            type_sequence_label_size=2,
            initializer_range=0.02,
            relative_attention=False,
            position_biased_input=True,
            pos_att_type="None",
            num_labels=3,
            num_choices=4,
            scope=None,
        ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.relative_attention = relative_attention
            self.position_biased_input = position_biased_input
            self.pos_att_type = pos_att_type
            self.scope = scope

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

            sequence_labels = None
            token_labels = None
            choice_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                choice_labels = ids_tensor([self.batch_size], self.num_choices)

            config = DebertaConfig(
                vocab_size=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                intermediate_size=self.intermediate_size,
                hidden_act=self.hidden_act,
                hidden_dropout_prob=self.hidden_dropout_prob,
                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                max_position_embeddings=self.max_position_embeddings,
                type_vocab_size=self.type_vocab_size,
                initializer_range=self.initializer_range,
                relative_attention=self.relative_attention,
                position_biased_input=self.position_biased_input,
                pos_att_type=self.pos_att_type,
            )

            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

        def check_loss_output(self, result):
            self.parent.assertListEqual(list(result.loss.size()), [])

        def create_and_check_deberta_model(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            model = DebertaModel(config=config)
            model.to(torch_device)
            model.eval()
            sequence_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
            sequence_output = model(input_ids, token_type_ids=token_type_ids)[0]
            sequence_output = model(input_ids)[0]

            self.parent.assertListEqual(
                list(sequence_output.size()), [self.batch_size, self.seq_length, self.hidden_size]
            )

        def create_and_check_deberta_for_masked_lm(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            model = DebertaForMaskedLM(config=config)
            model.to(torch_device)
            model.eval()
            result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)

            self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

        def create_and_check_deberta_for_sequence_classification(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            config.num_labels = self.num_labels
            model = DebertaForSequenceClassification(config)
            model.to(torch_device)
            model.eval()
            result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
            self.parent.assertListEqual(list(result.logits.size()), [self.batch_size, self.num_labels])
            self.check_loss_output(result)

        def create_and_check_deberta_for_token_classification(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            config.num_labels = self.num_labels
            model = DebertaForTokenClassification(config=config)
            model.to(torch_device)
            model.eval()
            result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
            self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

        def create_and_check_deberta_for_question_answering(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            model = DebertaForQuestionAnswering(config=config)
            model.to(torch_device)
            model.eval()
            result = model(
                input_ids,
                attention_mask=input_mask,
                token_type_ids=token_type_ids,
                start_positions=sequence_labels,
                end_positions=sequence_labels,
            )
            self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
            self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (
                config,
                input_ids,
                token_type_ids,
                input_mask,
                sequence_labels,
                token_labels,
                choice_labels,
            ) = config_and_inputs
            inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = DebertaModelTest.DebertaModelTester(self)
        self.config_tester = ConfigTester(self, config_class=DebertaConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_deberta_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_deberta_model(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_deberta_for_sequence_classification(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_deberta_for_masked_lm(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_deberta_for_question_answering(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_deberta_for_token_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = DebertaModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


@require_torch
@require_sentencepiece
@require_tokenizers
class DebertaModelIntegrationTest(unittest.TestCase):
    @unittest.skip(reason="Model not available yet")
    def test_inference_masked_lm(self):
        pass

    @slow
    def test_inference_no_head(self):
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed_all(0)
        model = DebertaModel.from_pretrained("microsoft/deberta-base")

        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
        output = model(input_ids)[0]
        # compare the actual values for a slice.
        expected_slice = torch.tensor(
            [[[-0.0218, -0.6641, -0.3665], [-0.3907, -0.4716, -0.6640], [0.7461, 1.2570, -0.9063]]]
        )
        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4), f"{output[:, :3, :3]}")