Unverified Commit 7d25c9c8 authored by Mohammed Jabir's avatar Mohammed Jabir Committed by GitHub
Browse files

added biogpt token classifier (#22447)



* added biogpt token classifier

* fix reviews

* Updated modeling_biogpt.py
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 1194c3e3
......@@ -54,3 +54,8 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The
[[autodoc]] BioGptForCausalLM
- forward
## BioGptForTokenClassification
[[autodoc]] BioGptForTokenClassification
- forward
\ No newline at end of file
......@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
<!--End of the generated tip-->
......
......@@ -1131,6 +1131,7 @@ else:
[
"BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"BioGptForCausalLM",
"BioGptForTokenClassification",
"BioGptModel",
"BioGptPreTrainedModel",
]
......@@ -4703,6 +4704,7 @@ if TYPE_CHECKING:
from .models.biogpt import (
BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
BioGptForCausalLM,
BioGptForTokenClassification,
BioGptModel,
BioGptPreTrainedModel,
)
......
......@@ -782,6 +782,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("albert", "AlbertForTokenClassification"),
("bert", "BertForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"),
("biogpt", "BioGptForTokenClassification"),
("bloom", "BloomForTokenClassification"),
("camembert", "CamembertForTokenClassification"),
("canine", "CanineForTokenClassification"),
......
......@@ -30,6 +30,7 @@ else:
_import_structure["modeling_biogpt"] = [
"BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"BioGptForCausalLM",
"BioGptForTokenClassification",
"BioGptModel",
"BioGptPreTrainedModel",
]
......@@ -48,6 +49,7 @@ if TYPE_CHECKING:
from .modeling_biogpt import (
BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
BioGptForCausalLM,
BioGptForTokenClassification,
BioGptModel,
BioGptPreTrainedModel,
)
......
......@@ -25,7 +25,11 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_biogpt import BioGptConfig
......@@ -736,3 +740,95 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
@add_start_docstrings(
"""
BioGPT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
BIOGPT_START_DOCSTRING,
)
class BioGptForTokenClassification(BioGptPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.biogpt = BioGptModel(config)
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
classifier_dropout = config.classifier_dropout
else:
classifier_dropout = config.hidden_dropout_prob
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(BIOGPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.biogpt(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + transformer_outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
......@@ -1093,6 +1093,13 @@ class BioGptForCausalLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class BioGptForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BioGptModel(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -29,7 +29,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import BioGptForCausalLM, BioGptModel, BioGptTokenizer
from transformers import BioGptForCausalLM, BioGptForTokenClassification, BioGptModel, BioGptTokenizer
from transformers.models.biogpt.modeling_biogpt import BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST
......@@ -247,6 +247,16 @@ class BioGptModelTester:
self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
def create_and_check_biogpt_for_token_classification(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
):
config.num_labels = self.num_labels
model = BioGptForTokenClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
......@@ -264,10 +274,16 @@ class BioGptModelTester:
@require_torch
class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (BioGptModel, BioGptForCausalLM) if is_torch_available() else ()
all_model_classes = (BioGptModel, BioGptForCausalLM, BioGptForTokenClassification) if is_torch_available() else ()
all_generative_model_classes = (BioGptForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": BioGptModel, "text-generation": BioGptForCausalLM} if is_torch_available() else {}
{
"feature-extraction": BioGptModel,
"text-generation": BioGptForCausalLM,
"token-classification": BioGptForTokenClassification,
}
if is_torch_available()
else {}
)
test_pruning = False
......@@ -304,6 +320,10 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_biogpt_weight_initialization(*config_and_inputs)
def test_biogpt_token_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_biogpt_for_token_classification(*config_and_inputs)
@slow
def test_batch_generation(self):
model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment