"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e660424717daf47d4e511a78b9dda230a2f2a602"
Unverified commit 4cb5ffa9, authored by Anahita Bhiwandiwalla, committed by GitHub

Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval (#21684)



* Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval

* Minor fix for return_dict

* Implement test for loss computation

---------
Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com>
Co-authored-by: Tiep Le <tiep.le@intel.com>
parent 7f4f8b97
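
For orientation, here is a minimal usage sketch of the behaviour this commit enables: passing `labels` to the two heads now returns a loss. The checkpoint name and the way the labels are built mirror the integration tests added below; the image URL and caption are illustrative assumptions, not part of the diff.

```python
import requests
import torch
from PIL import Image

from transformers import (
    BridgeTowerForImageAndTextRetrieval,
    BridgeTowerForMaskedLM,
    BridgeTowerProcessor,
)

checkpoint = "BridgeTower/bridgetower-base-itm-mlm"  # same checkpoint as the integration tests below
processor = BridgeTowerProcessor.from_pretrained(checkpoint)

# Illustrative inputs (assumed): any image/caption pair works.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "two cats sleeping on a couch"
inputs = processor(image, text, return_tensors="pt")

# Image-text matching: label 1 means the pair matches, 0 means it does not.
itm_model = BridgeTowerForImageAndTextRetrieval.from_pretrained(checkpoint)
itm_inputs = dict(inputs, labels=torch.ones(1, dtype=torch.long))
print(itm_model(**itm_inputs).loss)

# Masked language modeling: labels are target token ids; positions set to -100 are ignored.
mlm_model = BridgeTowerForMaskedLM.from_pretrained(checkpoint)
mlm_inputs = dict(inputs, labels=inputs["input_ids"].clone())
print(mlm_model(**mlm_inputs).loss)
```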
@@ -22,6 +22,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN, QuickGELUActivation
 from ...modeling_outputs import (
@@ -1535,8 +1536,10 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
     ) -> Union[MaskedLMOutput, Tuple[torch.FloatTensor]]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels are currently not supported.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked); the loss is only computed for tokens with labels in `[0, ..., config.vocab_size]`.

         Returns:

         Examples:
@@ -1580,11 +1583,17 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
         )

         mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0])
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()  # -100 index = padding token
+            masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))

         if not return_dict:
-            return tuple(mlm_logits)
+            output = tuple(mlm_logits)
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

         return MaskedLMOutput(
+            loss=masked_lm_loss,
             logits=mlm_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
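
The masked-LM loss above follows the usual pattern: flatten the `(batch_size, sequence_length, vocab_size)` logits and the `(batch_size, sequence_length)` labels before applying `CrossEntropyLoss`, whose default `ignore_index` of `-100` skips masked-out positions. A standalone toy sketch (shapes are made up for illustration):

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy shapes; in the model, vocab_size comes from config.text_config.vocab_size.
batch_size, seq_len, vocab_size = 2, 5, 11
mlm_logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, 0] = -100  # positions set to -100 do not contribute to the loss

loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
masked_lm_loss = loss_fct(mlm_logits.view(-1, vocab_size), labels.view(-1))
print(masked_lm_loss)
```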
@@ -1627,8 +1636,9 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
     ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels are currently not supported.
+        labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
+            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they
+            match. The pairs with 0 will be skipped for calculation.

         Returns:

         Examples:
@@ -1673,11 +1683,17 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
         logits = self.itm_score(pooler_output)
+        itm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            itm_loss = loss_fct(logits, labels)

         if not return_dict:
-            return tuple(logits)
+            output = tuple(logits)
+            return ((itm_loss,) + output) if itm_loss is not None else output

         return SequenceClassifierOutput(
-            loss=None,
+            loss=itm_loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
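
The image-text matching loss treats retrieval as two-class classification: the `itm_score` head produces one (no-match, match) logit pair per example, and the integer labels select the target class. A minimal standalone illustration (tensor values are arbitrary):

```python
import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(3, 2)        # (batch_size, 2) logits from an ITM-style head
labels = torch.tensor([1, 0, 1])  # 1 = matching image-text pair, 0 = non-matching
itm_loss = CrossEntropyLoss()(logits, labels)
print(itm_loss)
```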
...
@@ -392,6 +392,13 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
         self.assertEqual(outputs.logits.shape, expected_shape)
         self.assertTrue(outputs.logits[0, 1].item() > outputs.logits[0, 0].item())

+        # verify loss
+        inputs["labels"] = torch.ones(1, dtype=torch.long, device=torch_device)
+        inputs = inputs.to(torch_device)
+        with torch.no_grad():
+            outputs = model(**inputs)
+        self.assertAlmostEqual(outputs.loss.item(), 0.5108, places=4)
+
     @slow
     def test_masked_language_modeling(self):
         model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm").to(torch_device)
@@ -412,3 +419,62 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
         # verify predicted word
         predicted_id = outputs.logits.argmax(dim=-1).squeeze(0).tolist()[4]
         self.assertTrue(processor.decode([predicted_id]) == " cats")
+
+        # verify loss
+        inputs["labels"] = inputs["input_ids"].clone()
+        inputs = inputs.to(torch_device)
+        with torch.no_grad():
+            outputs = model(**inputs)
+        self.assertAlmostEqual(outputs.loss.item(), 5.7373, places=4)
+
+
+@require_torch
+@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
+class BridgeTowerModelTrainingTest(unittest.TestCase):
+    all_training_supported_model_classes = (
+        (BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
+    )
+
+    def setUp(self):
+        self.model_tester = BridgeTowerModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+
+    def _prepare_inputs_for_training(self, model_class):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if model_class == BridgeTowerForMaskedLM:
+            inputs_dict["labels"] = inputs_dict["input_ids"]
+        elif model_class == BridgeTowerForImageAndTextRetrieval:
+            inputs_dict["labels"] = ids_tensor([1], 2)
+        return config, inputs_dict
+
+    def _get_non_used_layer_names(self, model_class):
+        non_used_layer_names = ["text_model.pooler"]
+        if model_class == BridgeTowerForMaskedLM:
+            non_used_layer_names = non_used_layer_names + [
+                "cross_modal_image_layers.5",
+                "cross_modal_image_pooler",
+                "cross_modal_text_pooler",
+            ]
+        return non_used_layer_names
+
+    def _is_layer_used(self, model_class, layer_name):
+        non_used_layer_names = self._get_non_used_layer_names(model_class)
+        for non_used_layer_name in non_used_layer_names:
+            if non_used_layer_name in layer_name:
+                return False
+        return True
+
+    def test_training(self):
+        for model_class in self.all_training_supported_model_classes:
+            config, inputs_dict = self._prepare_inputs_for_training(model_class)
+            model = model_class(config)
+            model.to(torch_device)
+            model.train()
+
+            loss = model(**inputs_dict).loss
+            loss.backward()
+
+            # verify the gradients of used layers' weights are not None
+            for name, param in model.named_parameters():
+                if self._is_layer_used(model_class, name):
+                    self.assertIsNotNone(param.grad, f"Gradients should not be None - got {param.grad} for {name}")