Unverified Commit edb672ac authored by Hailey Schoelkopf, committed by GitHub

Add `BloomForSequenceClassification` and `BloomForTokenClassification` classes (#17639)



* add new bloom classes

* (feat) add bloom classification tests; make style

* style: change import in test

* add some typehints to bloom classes

* merge main into branch

* fix: input checking in bloom seq classification

* fix tests

* change model class tests

* fix a few tests

- more tests should pass
- one test left

* make token classifier return hidden states

* style: make BLOOM typehints consistent
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent bd43151a
...@@ -45,3 +45,13 @@ Several smaller versions of the models have been trained on the same dataset. BL
[[autodoc]] BloomForCausalLM
- forward
## BloomForSequenceClassification
[[autodoc]] BloomForSequenceClassification
- forward
## BloomForTokenClassification
[[autodoc]] BloomForTokenClassification
- forward
\ No newline at end of file
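For readers of the new API doc entries above, here is a minimal, hedged usage sketch of the two new heads. The `bigscience/bloom-1b3` checkpoint name is taken from the archive list later in this diff; using it as the base for a classification head is only an illustration, and the head weights are newly initialized rather than pretrained.

```python
from transformers import (
    BloomForSequenceClassification,
    BloomForTokenClassification,
    BloomTokenizerFast,
)

tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-1b3")
inputs = tokenizer("BLOOM is a multilingual language model.", return_tensors="pt")

# Sequence classification: one set of logits per input sequence.
seq_clf = BloomForSequenceClassification.from_pretrained("bigscience/bloom-1b3", num_labels=2)
print(seq_clf(**inputs).logits.shape)  # torch.Size([1, 2])

# Token classification: one set of logits per input token.
tok_clf = BloomForTokenClassification.from_pretrained("bigscience/bloom-1b3", num_labels=5)
print(tok_clf(**inputs).logits.shape)  # torch.Size([1, seq_len, 5])
```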
...@@ -870,6 +870,8 @@ else:
"BloomForCausalLM",
"BloomModel",
"BloomPreTrainedModel",
"BloomForSequenceClassification",
"BloomForTokenClassification",
]
)
_import_structure["models.blenderbot"].extend(
...@@ -3417,6 +3419,8 @@ if TYPE_CHECKING:
from .models.bloom import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel,
BloomPreTrainedModel,
)
...
...@@ -453,6 +453,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("bert", "BertForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"),
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
("bloom", "BloomForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"), ("camembert", "CamembertForSequenceClassification"),
("canine", "CanineForSequenceClassification"), ("canine", "CanineForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"), ("convbert", "ConvBertForSequenceClassification"),
...@@ -563,6 +564,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -563,6 +564,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("albert", "AlbertForTokenClassification"), ("albert", "AlbertForTokenClassification"),
("bert", "BertForTokenClassification"), ("bert", "BertForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"), ("big_bird", "BigBirdForTokenClassification"),
("bloom", "BloomForTokenClassification"),
("camembert", "CamembertForTokenClassification"), ("camembert", "CamembertForTokenClassification"),
("canine", "CanineForTokenClassification"), ("canine", "CanineForTokenClassification"),
("convbert", "ConvBertForTokenClassification"), ("convbert", "ConvBertForTokenClassification"),
......
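Because of the two mapping entries added above, the auto classes can now resolve BLOOM checkpoints to the new heads. A short, hedged sketch (checkpoint name reused from this diff for illustration; the classification heads are randomly initialized):

```python
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification

seq_model = AutoModelForSequenceClassification.from_pretrained("bigscience/bloom-1b3", num_labels=3)
tok_model = AutoModelForTokenClassification.from_pretrained("bigscience/bloom-1b3", num_labels=9)

print(type(seq_model).__name__)  # BloomForSequenceClassification
print(type(tok_model).__name__)  # BloomForTokenClassification
```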
...@@ -46,6 +46,8 @@ else:
"BloomForCausalLM",
"BloomModel",
"BloomPreTrainedModel",
"BloomForSequenceClassification",
"BloomForTokenClassification",
]
if TYPE_CHECKING:
...@@ -68,6 +70,8 @@ if TYPE_CHECKING:
from .modeling_bloom import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel,
BloomPreTrainedModel,
)
...
...@@ -15,15 +15,20 @@
"""PyTorch BLOOM model."""
import math
-from typing import Tuple
+from typing import Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
-from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
+from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_bloom import BloomConfig
...@@ -42,7 +47,7 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"bigscience/bloom-1b3",
"bigscience/bloom-2b5",
"bigscience/bloom-6b3",
"bigscience/bloom-176b", "bigscience/bloom",
] ]
...@@ -726,7 +731,7 @@ class BloomModel(BloomPreTrainedModel): ...@@ -726,7 +731,7 @@ class BloomModel(BloomPreTrainedModel):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
): ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
...@@ -902,7 +907,7 @@ class BloomForCausalLM(BloomPreTrainedModel): ...@@ -902,7 +907,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
): ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
...@@ -959,3 +964,223 @@ class BloomForCausalLM(BloomPreTrainedModel): ...@@ -959,3 +964,223 @@ class BloomForCausalLM(BloomPreTrainedModel):
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past for layer_past in past
) )
@add_start_docstrings(
"""
The Bloom Model transformer with a sequence classification head on top (linear layer).
[`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-1) do.
Since it does classification on the last token, it needs to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
each row of the batch).
""",
BLOOM_START_DOCSTRING,
)
class BloomForSequenceClassification(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = BloomModel(config)
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
`config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
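To make the pooling rule in `BloomForSequenceClassification.forward` concrete: when a `pad_token_id` is configured, the logit used for each example is the one at its last non-padding position. A standalone sketch of that indexing, with made-up tensor values and an assumed `pad_token_id` of 3:

```python
import torch

pad_token_id = 3  # assumed value, for illustration only
input_ids = torch.tensor([
    [5, 8, 9, 3, 3],   # 3 real tokens, 2 pads -> last real token at index 2
    [7, 6, 2, 4, 11],  # no padding -> last token at index 4
])
logits = torch.randn(2, 5, 4)  # (batch_size, seq_len, num_labels), e.g. the output of self.score(hidden_states)

# Same computation as in the forward pass above.
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1  # tensor([2, 4])
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 4]): one logit vector per sequence
```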
@add_start_docstrings(
"""
Bloom Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
BLOOM_START_DOCSTRING,
)
class BloomForTokenClassification(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = BloomModel(config)
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
classifier_dropout = config.classifier_dropout
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + transformer_outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
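The loss in `BloomForTokenClassification.forward` is a plain cross-entropy over every token position. A self-contained sketch of that reshaping, with arbitrary illustrative shapes:

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, num_labels = 2, 6, 5
logits = torch.randn(batch_size, seq_len, num_labels)         # classifier output, one row per token
labels = torch.randint(0, num_labels, (batch_size, seq_len))  # one class id per token

# Flatten to (batch_size * seq_len, num_labels) vs. (batch_size * seq_len,),
# exactly as done in the forward pass above.
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
print(loss)  # scalar token-classification loss
```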
...@@ -558,6 +558,18 @@ class Trainer:
)
self.use_apex = True
# FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
if (
is_sagemaker_mp_enabled()
and self.use_cuda_amp
and args.max_grad_norm is not None
and args.max_grad_norm > 0
):
raise ValueError(
"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
"along 'max_grad_norm': 0 in your hyperparameters."
)
# Label smoothing
if self.args.label_smoothing_factor != 0:
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
...
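For users who hit the new `ValueError` above, the workaround the message points at can be expressed through `TrainingArguments`; a hedged sketch (argument values are illustrative only, and a CUDA-capable SageMaker environment is assumed for `fp16=True`):

```python
from transformers import TrainingArguments

# Disable gradient clipping so the SageMaker MP + mixed-precision check above passes.
args = TrainingArguments(
    output_dir="bloom-classification",  # arbitrary example path
    fp16=True,
    max_grad_norm=0.0,  # 0 means "no clipping", satisfying the check above
)
```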
...@@ -986,6 +986,20 @@ class BloomForCausalLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class BloomForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BloomForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BloomModel(metaclass=DummyObject):
_backends = ["torch"]
...
...@@ -28,7 +28,14 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attenti
if is_torch_available():
import torch
-from transformers import BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST, BloomForCausalLM, BloomModel, BloomTokenizerFast
+from transformers import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel,
BloomTokenizerFast,
)
@require_torch
...@@ -96,9 +103,13 @@ class BloomModelTester:
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
sequence_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
-return (config, input_ids, input_mask)
+return (config, input_ids, input_mask, sequence_labels)
def get_config(self, gradient_checkpointing=False, slow_but_exact=True):
return BloomConfig(
...@@ -116,6 +127,7 @@ class BloomModelTester:
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
num_labels=self.num_labels,
gradient_checkpointing=gradient_checkpointing,
slow_but_exact=slow_but_exact,
dtype="float32",
...@@ -245,6 +257,23 @@ class BloomModelTester:
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_sequence_classification_model(self, config, input_ids, input_mask, *args):
config.num_labels = self.num_labels
model = BloomForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_token_classification_model(self, config, input_ids, input_mask, *args):
model = BloomForTokenClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, *args, gradient_checkpointing=False
):
...@@ -269,7 +298,7 @@ class BloomModelTester:
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
-config, input_ids, input_mask = config_and_inputs
+config, input_ids, input_mask, sequence_labels = config_and_inputs
inputs_dict = {"input_ids": input_ids}
...@@ -279,7 +308,17 @@
@require_torch
class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
-all_model_classes = (BloomModel, BloomForCausalLM) if is_torch_available() else ()
+all_model_classes = (
(
BloomModel,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else ()
fx_compatible = False
test_missing_keys = False
...@@ -313,6 +352,14 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
def test_bloom_sequence_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_sequence_classification_model(*config_and_inputs)
def test_bloom_token_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_token_classification_model(*config_and_inputs)
def test_bloom_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
...