Unverified Commit f2514413 authored by lewtun, committed by GitHub

Add LlamaForSequenceClassification (#22209)



* Add LlamaForSequenceClassification

* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Add docstring

* Add test

* Add input embedding getter and setter

* Remove dead code

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 675d2a5a
@@ -66,3 +66,8 @@ This model was contributed by [zphang](https://huggingface.co/zphang) with contr
 [[autodoc]] LlamaForCausalLM
     - forward
+
+## LlamaForSequenceClassification
+
+[[autodoc]] LlamaForSequenceClassification
+    - forward
\ No newline at end of file
@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit
 <!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
 <!--End of the generated tip-->
...
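For readers coming from the task guide, a minimal loading sketch for the newly listed architecture; the checkpoint path and label count below are placeholders, not released artifacts:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# "path/to/llama-checkpoint" is a placeholder -- point it at a local or hub LLaMA checkpoint you have access to.
tokenizer = AutoTokenizer.from_pretrained("path/to/llama-checkpoint")
model = AutoModelForSequenceClassification.from_pretrained("path/to/llama-checkpoint", num_labels=2)

inputs = tokenizer("This film was a pleasant surprise.", return_tensors="pt")
outputs = model(**inputs)
predicted_class_id = outputs.logits.argmax(dim=-1).item()

# For batched, padded inputs, make sure `model.config.pad_token_id` is set so the
# classification head can locate the last real token in each row (see the modeling change below).
```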
@@ -1801,11 +1801,7 @@ else:
         ]
     )
     _import_structure["models.llama"].extend(
-        [
-            "LlamaForCausalLM",
-            "LlamaModel",
-            "LlamaPreTrainedModel",
-        ]
+        ["LlamaForCausalLM", "LlamaForSequenceClassification", "LlamaModel", "LlamaPreTrainedModel"]
     )
     _import_structure["models.longformer"].extend(
         [
@@ -5198,11 +5194,7 @@ if TYPE_CHECKING:
             LiltModel,
             LiltPreTrainedModel,
         )
-        from .models.llama import (
-            LlamaForCausalLM,
-            LlamaModel,
-            LlamaPreTrainedModel,
-        )
+        from .models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
         from .models.longformer import (
             LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
             LongformerForMaskedLM,
...
@@ -652,6 +652,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
         ("led", "LEDForSequenceClassification"),
         ("lilt", "LiltForSequenceClassification"),
+        ("llama", "LlamaForSequenceClassification"),
         ("longformer", "LongformerForSequenceClassification"),
         ("luke", "LukeForSequenceClassification"),
         ("markuplm", "MarkupLMForSequenceClassification"),
...
@@ -43,6 +43,7 @@ else:
         "LlamaForCausalLM",
         "LlamaModel",
         "LlamaPreTrainedModel",
+        "LlamaForSequenceClassification",
     ]
@@ -63,11 +64,7 @@ if TYPE_CHECKING:
     except OptionalDependencyNotAvailable:
         pass
     else:
-        from .modeling_llama import (
-            LlamaForCausalLM,
-            LlamaModel,
-            LlamaPreTrainedModel,
-        )
+        from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
 else:
...
@@ -24,19 +24,12 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
-from ...utils import (
-    add_start_docstrings,
-    logging,
-    replace_return_docstrings,
-)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_llama import LlamaConfig
@@ -357,7 +350,7 @@ LLAMA_START_DOCSTRING = r"""
 @add_start_docstrings(
-    "The bare OPT Model outputting raw hidden-states without any specific head on top.",
+    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     LLAMA_START_DOCSTRING,
 )
 class LlamaPreTrainedModel(PreTrainedModel):
@@ -831,3 +824,122 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
+
+
+@add_start_docstrings(
+    """
+    The LLaMa Model transformer with a sequence classification head on top (linear layer).
+
+    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it requires to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    LLAMA_START_DOCSTRING,
+)
+class LlamaForSequenceClassification(LlamaPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = LlamaModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.model(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+            else:
+                sequence_lengths = -1
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
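The pooling rule described in the class docstring above is easier to see in isolation. Below is a small standalone sketch; the tensor values, pad id, and label count are invented for illustration, and right padding is assumed:

```python
import torch

pad_token_id = 0  # hypothetical pad id for this sketch
input_ids = torch.tensor(
    [
        [11, 42, 7, 0, 0],   # 3 real tokens -> last non-pad token at index 2
        [23, 5, 19, 31, 8],  # 5 real tokens -> last token at index 4
    ]
)
# Per-token logits as produced by the `score` head: (batch_size, seq_len, num_labels)
logits = torch.randn(2, 5, 3)

# Count the non-padding tokens in each row and subtract 1 to get the index of the last real token.
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1  # tensor([2, 4])

# Gather one logit vector per sequence at that position.
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])
```

When no `pad_token_id` is configured, the forward pass above falls back to index `-1`, i.e. the final position of every row, which is also what happens when only `inputs_embeds` are provided.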
@@ -3768,6 +3768,13 @@ class LlamaForCausalLM(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class LlamaForSequenceClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class LlamaModel(metaclass=DummyObject):
     _backends = ["torch"]
...
@@ -27,7 +27,7 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attenti
 if is_torch_available():
     import torch
 
-    from transformers import LlamaForCausalLM, LlamaModel
+    from transformers import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
 
 
 class LlamaModelTester:
@@ -255,14 +255,7 @@ class LlamaModelTester:
 @require_torch
 class LlamaModelTest(ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (
-        (
-            LlamaModel,
-            LlamaForCausalLM,
-        )
-        if is_torch_available()
-        else ()
-    )
+    all_model_classes = (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification) if is_torch_available() else ()
     all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else ()
     test_headmasking = False
@@ -283,6 +276,46 @@ class LlamaModelTest(ModelTesterMixin, unittest.TestCase):
             config_and_inputs[0].position_embedding_type = type
             self.model_tester.create_and_check_model(*config_and_inputs)
 
+    def test_llama_sequence_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
+        model = LlamaForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
+    def test_llama_sequence_classification_model_for_single_label(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        config.problem_type = "single_label_classification"
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
+        model = LlamaForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
+    def test_llama_sequence_classification_model_for_multi_label(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        config.problem_type = "multi_label_classification"
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        sequence_labels = ids_tensor(
+            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
+        ).to(torch.float)
+        model = LlamaForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
     @unittest.skip("LLaMA does not support head pruning.")
     def test_head_pruning(self):
         pass
...
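As a quick reference for the label conventions these tests exercise, here is a sketch with a hypothetical tiny, randomly initialized config; the sizes and values are made up, and only the shapes and dtypes matter:

```python
import torch
from transformers import LlamaConfig, LlamaForSequenceClassification

# Tiny config purely for illustration -- not a real checkpoint.
config = LlamaConfig(
    vocab_size=100,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_labels=3,
    pad_token_id=0,
)
model = LlamaForSequenceClassification(config).eval()
input_ids = torch.randint(1, config.vocab_size, (2, 8))  # no padding tokens in this batch

# Single-label classification: integer class indices of shape (batch_size,) -> CrossEntropyLoss.
single_labels = torch.tensor([0, 2])
out = model(input_ids, labels=single_labels)
print(out.loss.shape, out.logits.shape)  # torch.Size([]), torch.Size([2, 3])

# Multi-label classification: float multi-hot targets of shape (batch_size, num_labels) -> BCEWithLogitsLoss.
config.problem_type = "multi_label_classification"
multi_labels = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
out = model(input_ids, labels=multi_labels)
print(out.loss.shape, out.logits.shape)  # torch.Size([]), torch.Size([2, 3])
```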