Unverified Commit 487f132a authored by Ashwin Mathur, committed by GitHub

Add `BioGPTForSequenceClassification` (#22253)



* added BioGptForSequenceClassification

* added source of copied code

* typo

* Format code with black

* Update comments for copied code

* Remove code copy comment

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix failing tests

* Update code copied from comments

* Fix code quality

* Update src/transformers/models/biogpt/modeling_biogpt.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix lint error

* Update src/transformers/models/biogpt/modeling_biogpt.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Rename model to biogpt for consistency

* Add PipelineTesterMixin to test_modeling_biogpt.py

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Resolve merge conflict

---------
Co-authored-by: Guillem García Subies <37592763+GuillemGSubies@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 549e5f9f
@@ -55,7 +55,14 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The
 [[autodoc]] BioGptForCausalLM
     - forward
 
 ## BioGptForTokenClassification
 
 [[autodoc]] BioGptForTokenClassification
     - forward
+
+## BioGptForSequenceClassification
+
+[[autodoc]] BioGptForSequenceClassification
+    - forward
\ No newline at end of file
@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit
 <!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
 
 <!--End of the generated tip-->
...
@@ -1147,6 +1147,7 @@ else:
         [
             "BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
             "BioGptForCausalLM",
+            "BioGptForSequenceClassification",
             "BioGptForTokenClassification",
             "BioGptModel",
             "BioGptPreTrainedModel",
@@ -4792,6 +4793,7 @@ if TYPE_CHECKING:
     from .models.biogpt import (
         BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
         BioGptForCausalLM,
+        BioGptForSequenceClassification,
         BioGptForTokenClassification,
         BioGptModel,
         BioGptPreTrainedModel,
...
@@ -648,6 +648,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("bert", "BertForSequenceClassification"),
         ("big_bird", "BigBirdForSequenceClassification"),
         ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
+        ("biogpt", "BioGptForSequenceClassification"),
         ("bloom", "BloomForSequenceClassification"),
         ("camembert", "CamembertForSequenceClassification"),
         ("canine", "CanineForSequenceClassification"),
...
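With the `("biogpt", "BioGptForSequenceClassification")` entry registered above, the Auto API can dispatch to the new class from a BioGPT config. The snippet below is a minimal sketch, not part of the commit; it assumes the base `microsoft/biogpt` checkpoint, whose classification head would be newly initialized rather than fine-tuned.

    # Sketch only: the "biogpt" entry in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
    # lets AutoModelForSequenceClassification resolve to BioGptForSequenceClassification.
    # Expect an initialization warning for the untrained `score` head.
    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained("microsoft/biogpt", num_labels=2)
    print(type(model).__name__)  # BioGptForSequenceClassification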
@@ -31,6 +31,7 @@ else:
         "BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
         "BioGptForCausalLM",
         "BioGptForTokenClassification",
+        "BioGptForSequenceClassification",
         "BioGptModel",
         "BioGptPreTrainedModel",
     ]
@@ -49,6 +50,7 @@ if TYPE_CHECKING:
     from .modeling_biogpt import (
         BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
         BioGptForCausalLM,
+        BioGptForSequenceClassification,
         BioGptForTokenClassification,
         BioGptModel,
         BioGptPreTrainedModel,
...
@@ -22,16 +22,22 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
+    SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+)
 from .configuration_biogpt import BioGptConfig
@@ -40,8 +46,10 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "microsoft/biogpt"
 _CONFIG_FOR_DOC = "BioGptConfig"
 
 BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "microsoft/biogpt",
+    "microsoft/BioGPT-Large",
     # See all BioGPT models at https://huggingface.co/models?filter=biogpt
 ]
@@ -832,3 +840,129 @@ class BioGptForTokenClassification(BioGptPreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The BioGpt Model transformer with a sequence classification head on top (linear layer).
+
+    [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it is required to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    BIOGPT_START_DOCSTRING,
+)
+class BioGptForSequenceClassification(BioGptPreTrainedModel):
+    def __init__(self, config: BioGptConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.biogpt = BioGptModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(BIOGPT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.biogpt(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        if self.config.pad_token_id is None:
+            sequence_length = -1
+        else:
+            if input_ids is not None:
+                sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+            else:
+                sequence_length = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    def get_input_embeddings(self):
+        return self.biogpt.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.biogpt.embed_tokens = value
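For readers who want to try the new head directly, here is a minimal usage sketch (not part of the commit). It assumes the base `microsoft/biogpt` checkpoint, so the two-label `score` layer is randomly initialized and the outputs are only illustrative. Because `BioGptConfig` defines a `pad_token_id`, the pooled logit for each row comes from its last non-padding token, and with integer labels and `num_labels > 1` the forward pass falls back to a cross-entropy loss.

    # Minimal sketch, not from the commit: single-label classification with the new head.
    # "microsoft/biogpt" is the base checkpoint; its classification head is untrained.
    import torch
    from transformers import BioGptForSequenceClassification, BioGptTokenizer

    tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
    model = BioGptForSequenceClassification.from_pretrained("microsoft/biogpt", num_labels=2)

    inputs = tokenizer(
        ["BioGPT was pre-trained on biomedical literature.", "Short example."],
        return_tensors="pt",
        padding=True,
    )
    labels = torch.tensor([1, 0])

    with torch.no_grad():
        outputs = model(**inputs, labels=labels)
    print(outputs.logits.shape)  # torch.Size([2, 2]); pooled from each row's last non-padding token
    print(outputs.loss.item())   # cross-entropy loss (single_label_classification)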
@@ -1103,6 +1103,13 @@ class BioGptForCausalLM(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class BioGptForSequenceClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class BioGptForTokenClassification(metaclass=DummyObject):
     _backends = ["torch"]
...
@@ -29,7 +29,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
     import torch
 
-    from transformers import BioGptForCausalLM, BioGptForTokenClassification, BioGptModel, BioGptTokenizer
+    from transformers import (
+        BioGptForCausalLM,
+        BioGptForSequenceClassification,
+        BioGptForTokenClassification,
+        BioGptModel,
+        BioGptTokenizer,
+    )
     from transformers.models.biogpt.modeling_biogpt import BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -274,13 +280,18 @@ class BioGptModelTester:
 @require_torch
 class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (BioGptModel, BioGptForCausalLM, BioGptForTokenClassification) if is_torch_available() else ()
+    all_model_classes = (
+        (BioGptModel, BioGptForCausalLM, BioGptForSequenceClassification, BioGptForTokenClassification)
+        if is_torch_available()
+        else ()
+    )
     all_generative_model_classes = (BioGptForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": BioGptModel,
             "text-generation": BioGptForCausalLM,
             "token-classification": BioGptForTokenClassification,
+            "text-classification": BioGptForSequenceClassification,
         }
         if is_torch_available()
         else {}
@@ -374,6 +385,35 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
             model = BioGptModel.from_pretrained(model_name)
             self.assertIsNotNone(model)
 
+    # Copied from tests.models.opt.test_modeling_opt.OPTModelTest with OPT->BioGpt, prepare_config_and_inputs-> prepare_config_and_inputs_for_common
+    def test_biogpt_sequence_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
+        model = BioGptForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
+    # Copied from tests.models.opt.test_modeling_opt.OPTModelTest with OPT->BioGpt, prepare_config_and_inputs-> prepare_config_and_inputs_for_common
+    def test_biogpt_sequence_classification_model_for_multi_label(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        config.problem_type = "multi_label_classification"
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        sequence_labels = ids_tensor(
+            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
+        ).to(torch.float)
+        model = BioGptForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
 
 @require_torch
 class BioGptModelIntegrationTest(unittest.TestCase):
...
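With the `text-classification` entry added to `pipeline_model_mapping` above, a BioGPT checkpoint fine-tuned with this head can also be used through the pipeline API. A hedged sketch follows; the checkpoint name is a placeholder, not a published model.

    from transformers import pipeline

    # Placeholder repo id: substitute a BioGPT model fine-tuned with
    # BioGptForSequenceClassification; the base checkpoint alone would give random labels.
    classifier = pipeline("text-classification", model="your-org/biogpt-finetuned-textcls")
    print(classifier("The patient showed marked improvement after treatment."))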