Unverified Commit 2b096508 authored by gilad19, committed by GitHub

Add ViltForTokenClassification e.g. for Named-Entity-Recognition (NER) (#17924)

* Add ViltForTokenClassification e.g. for Named-Entity-Recognition (NER)

* Add ViltForTokenClassification e.g. for Named-Entity-Recognition (NER)

* provide classifier only text hidden states

* add test_for_token_classification

* Update src/transformers/models/vilt/modeling_vilt.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/vilt/modeling_vilt.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/vilt/modeling_vilt.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/vilt/modeling_vilt.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* add test_for_token_classification

Co-authored-by: gfuchs <gfuchs@ebay.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent 002915aa
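Before the diff, a hedged usage sketch of what this PR enables (this snippet is not part of the commit): loading the new head on top of existing base weights and running a forward pass. `dandelin/vilt-b32-mlm` is an existing base checkpoint; the token-classification head itself starts randomly initialized, so real NER use would require fine-tuning first.

```python
import torch
from PIL import Image
from transformers import ViltProcessor, ViltForTokenClassification

# Base ViLT weights; the new token-classification head is randomly initialized
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
model = ViltForTokenClassification.from_pretrained("dandelin/vilt-b32-mlm", num_labels=9)

image = Image.open("cats.jpg")  # any local image
inputs = processor(image, "Two cats lying on a couch", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# One logit vector per *text* token; image patches are excluded from the head
print(outputs.logits.shape)  # (1, text_sequence_length, num_labels)
```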
@@ -87,3 +87,8 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
 [[autodoc]] ViltForImageAndTextRetrieval
     - forward
+
+## ViltForTokenClassification
+
+[[autodoc]] ViltForTokenClassification
+    - forward
@@ -1816,6 +1816,7 @@ else:
         "VILT_PRETRAINED_MODEL_ARCHIVE_LIST",
         "ViltForImageAndTextRetrieval",
         "ViltForImagesAndTextClassification",
+        "ViltForTokenClassification",
         "ViltForMaskedLM",
         "ViltForQuestionAnswering",
         "ViltLayer",
@@ -4317,6 +4318,7 @@ if TYPE_CHECKING:
         ViltForImagesAndTextClassification,
         ViltForMaskedLM,
         ViltForQuestionAnswering,
+        ViltForTokenClassification,
         ViltLayer,
         ViltModel,
         ViltPreTrainedModel,
...
@@ -42,6 +42,7 @@ else:
         "VILT_PRETRAINED_MODEL_ARCHIVE_LIST",
         "ViltForImageAndTextRetrieval",
         "ViltForImagesAndTextClassification",
+        "ViltForTokenClassification",
         "ViltForMaskedLM",
         "ViltForQuestionAnswering",
         "ViltLayer",
@@ -74,6 +75,7 @@ if TYPE_CHECKING:
         ViltForImagesAndTextClassification,
         ViltForMaskedLM,
         ViltForQuestionAnswering,
+        ViltForTokenClassification,
         ViltLayer,
         ViltModel,
         ViltPreTrainedModel,
...
@@ -32,6 +32,7 @@ from ...modeling_outputs import (
     MaskedLMOutput,
     ModelOutput,
     SequenceClassifierOutput,
+    TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
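For orientation, `TokenClassifierOutput` is the standard output dataclass that the new forward pass below returns. A minimal illustration of its fields (the shapes here are made up):

```python
import torch
from transformers.modeling_outputs import TokenClassifierOutput

# Fields: loss, logits, hidden_states, attentions (all optional except logits here)
out = TokenClassifierOutput(logits=torch.randn(1, 11, 9))
print(out.logits.shape)  # torch.Size([1, 11, 9])
print(out.loss)          # None when no labels were supplied
```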
@@ -1402,3 +1403,90 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
             hidden_states=hidden_states,
             attentions=attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    ViLT Model with a token classification head on top (a linear layer on top of the final hidden-states of the text
+    tokens) e.g. for Named-Entity-Recognition (NER) tasks.
+    """,
+    VILT_START_DOCSTRING,
+)
+class ViltForTokenClassification(ViltPreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.vilt = ViltModel(config, add_pooling_layer=False)
+
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        pixel_values=None,
+        pixel_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        image_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+
+        Returns:
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.vilt(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            pixel_values=pixel_values,
+            pixel_mask=pixel_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            image_embeds=image_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        text_input_size = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output[:, :text_input_size])
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
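The key design point above is the slice `sequence_output[:, :text_input_size]`: `ViltModel` returns hidden states for the concatenated text-plus-image sequence, and the classification head is applied only to the text positions, which is what makes the head usable for text-side tasks like NER. A standalone sketch of that slicing (shapes are illustrative, not taken from the PR):

```python
import torch

batch_size, text_len, num_patches, hidden_size = 2, 11, 145, 768

# ViltModel's last hidden state covers [text tokens ; image patches] concatenated
sequence_output = torch.randn(batch_size, text_len + num_patches, hidden_size)

# Only the first `text_len` positions are text tokens, so token-classification
# logits are computed over that slice alone
text_hidden_states = sequence_output[:, :text_len]
assert text_hidden_states.shape == (batch_size, text_len, hidden_size)
```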
@@ -4725,6 +4725,13 @@ class ViltForQuestionAnswering(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class ViltForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class ViltLayer(metaclass=DummyObject):
     _backends = ["torch"]
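The dummy class above follows the repo's pattern for optional backends: when PyTorch is not installed, importing `ViltForTokenClassification` still succeeds, but using it raises an informative error. A simplified sketch of the behavior (not the actual transformers implementation, whose mechanics live in `requires_backends` and `DummyObject`):

```python
# Simplified illustration of the dummy-object behavior
class ViltForTokenClassificationDummy:
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        raise ImportError(
            "ViltForTokenClassification requires the PyTorch library, "
            "but it was not found in your environment."
        )

try:
    ViltForTokenClassificationDummy()
except ImportError as e:
    print(e)  # fails at instantiation time, not at import time
```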
...
@@ -37,6 +37,7 @@ if is_torch_available():
         ViltForImagesAndTextClassification,
         ViltForMaskedLM,
         ViltForQuestionAnswering,
+        ViltForTokenClassification,
         ViltModel,
     )
     from transformers.models.vilt.modeling_vilt import VILT_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -173,6 +174,23 @@ class ViltModelTester:
             result.last_hidden_state.shape, (self.batch_size, self.expected_seq_len, self.hidden_size)
         )

+    def create_and_check_for_token_classification(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        pixel_values,
+        token_labels,
+    ):
+        model = ViltForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, pixel_values=pixel_values)
+        result = model(input_ids, token_type_ids=token_type_ids, pixel_values=pixel_values)
+        result = model(input_ids, pixel_values=pixel_values)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -204,6 +222,7 @@ class ViltModelTest(ModelTesterMixin, unittest.TestCase):
             ViltForQuestionAnswering,
             ViltForImageAndTextRetrieval,
             ViltForMaskedLM,
+            ViltForTokenClassification,
         )
         if is_torch_available()
         else ()
@@ -216,15 +235,12 @@ class ViltModelTest(ModelTesterMixin, unittest.TestCase):
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
-
-        # if model_class.__name__ == "ViltForNaturalLanguageVisualReasonining":
-        #     inputs_dict["pixel_values"] = floats_tensor([self.model_tester.batch_size, self.model_tester.num_images, self.model_tester.num_channels, self.model_tester.image_size, self.model_tester.image_size])
         if return_labels:
             if model_class.__name__ == "ViltForQuestionAnswering":
                 inputs_dict["labels"] = torch.zeros(
                     self.model_tester.batch_size, self.model_tester.num_labels, device=torch_device
                 )
-            elif model_class.__name__ == "ViltForMaskedLM":
+            elif model_class.__name__ in ["ViltForMaskedLM", "ViltForTokenClassification"]:
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                 )
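The test change above groups `ViltForTokenClassification` with `ViltForMaskedLM` because both heads predict one class per text token, so their dummy labels share the same shape and dtype. A quick sketch of how such labels feed the loss in the new forward (dimensions invented for illustration):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, text_seq_len, num_labels = 2, 11, 9

logits = torch.randn(batch_size, text_seq_len, num_labels)
labels = torch.zeros(batch_size, text_seq_len, dtype=torch.long)  # same shape as MaskedLM labels

# Mirrors the loss computation in ViltForTokenClassification.forward
loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
print(loss.item())
```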
@@ -246,6 +262,10 @@ class ViltModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

+    def test_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
     def test_training(self):
         if not self.model_tester.is_training:
             return
@@ -503,6 +523,10 @@ class ViltForImagesAndTextClassificationModelTest(ViltModelTest, unittest.TestCa
     def test_model(self):
         pass

+    @unittest.skip("We only test the model that takes in multiple images")
+    def test_for_token_classification(self):
+        pass
+

 # We will verify our results on an image of cute cats
 def prepare_img():
...
@@ -131,6 +131,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "ViltForQuestionAnswering",
     "ViltForImagesAndTextClassification",
     "ViltForImageAndTextRetrieval",
+    "ViltForTokenClassification",
     "ViltForMaskedLM",
     "XGLMEncoder",
     "XGLMDecoder",
...