"...kernels/git@developer.sourcefind.cn:change/sglang.git" did not exist on "295895120df4cdf133e07a93e389c92e4fb7d48d"
Unverified Commit 4cb5ffa9 authored by Anahita Bhiwandiwalla, committed by GitHub

Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval (#21684)



* Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval

* minor fix return_dict

* implement test for loss computation

---------
Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com>
Co-authored-by: Tiep Le <tiep.le@intel.com>
parent 7f4f8b97
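
After this change, both heads return a loss whenever `labels` are passed. A minimal usage sketch of the retrieval head (the checkpoint name comes from the tests below; the image URL and caption are only illustrative, not part of this commit):

```python
import requests
import torch
from PIL import Image
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor

# Illustrative image-text pair (the COCO cats image used throughout the Transformers docs).
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "two cats sleeping on a couch"

processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")

inputs = processor(image, text, return_tensors="pt")
inputs["labels"] = torch.ones(1, dtype=torch.long)  # 1 = the pair matches

outputs = model(**inputs)
print(outputs.loss, outputs.logits.shape)  # loss is now populated instead of None
```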
@@ -22,6 +22,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN, QuickGELUActivation
 from ...modeling_outputs import (
@@ -1535,8 +1536,10 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
     ) -> Union[MaskedLMOutput, Tuple[torch.FloatTensor]]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels are currently not supported.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

         Returns:

         Examples:
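
The `-100` convention above is PyTorch's default `ignore_index` for `CrossEntropyLoss`. A hedged sketch of how such labels are typically built for masked language modeling (the masking pattern here is made up and not part of this commit):

```python
import torch

# Illustrative only: labels start as a copy of the input ids, then every position
# that was NOT masked is set to -100 so CrossEntropyLoss ignores it.
input_ids = torch.tensor([[0, 100, 200, 300, 2]])   # some tokenized sequence
labels = input_ids.clone()

masked_positions = torch.zeros_like(input_ids, dtype=torch.bool)
masked_positions[0, 2] = True                        # pretend token 2 was replaced by <mask>

labels[~masked_positions] = -100                     # ignored positions
# only positions with labels in [0, vocab_size) contribute to the loss
```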
@@ -1580,11 +1583,17 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
         )

         mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0])
+        masked_lm_loss = None
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()  # -100 index = padding token
+            masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))

         if not return_dict:
-            return tuple(mlm_logits)
+            output = tuple(mlm_logits)
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

         return MaskedLMOutput(
+            loss=masked_lm_loss,
             logits=mlm_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
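
For reference, the flattening in the new loss code is the usual token-level cross-entropy: logits are reshaped to `(batch * seq_len, vocab_size)` and labels to `(batch * seq_len,)`. A standalone sketch with made-up shapes (50265 matches the vocabulary size used in the tests below):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 2, 8, 50265        # illustrative shapes
mlm_logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.full((batch_size, seq_len), -100)     # ignore every position...
labels[:, 3] = 42                                     # ...except one "masked" token per sequence

loss_fct = CrossEntropyLoss()                         # ignore_index defaults to -100
masked_lm_loss = loss_fct(mlm_logits.view(-1, vocab_size), labels.view(-1))
print(masked_lm_loss.item())                          # mean loss over the non-ignored positions only
```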
@@ -1627,8 +1636,9 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
     ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels are currently not supported.
+        labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
+            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
+            The pairs with 0 will be skipped for calculation.

         Returns:

         Examples:
@@ -1673,11 +1683,17 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
         logits = self.itm_score(pooler_output)

+        itm_loss = None
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            itm_loss = loss_fct(logits, labels)

         if not return_dict:
-            return tuple(logits)
+            output = tuple(logits)
+            return ((itm_loss,) + output) if itm_loss is not None else output

         return SequenceClassifierOutput(
-            loss=None,
+            loss=itm_loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
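
The ITM head is a two-way classifier, so the loss added above is a plain cross-entropy over two logits per image-text pair, with `labels` holding 0 (no match) or 1 (match). A tiny standalone sketch with dummy numbers:

```python
import torch
from torch.nn import CrossEntropyLoss

# Two image-text pairs, two ITM logits each: index 0 = "no match", index 1 = "match".
logits = torch.tensor([[0.2, 1.5], [0.9, -0.3]])
labels = torch.ones(2, dtype=torch.long)   # both pairs are ground-truth matches, as in the test below

itm_loss = CrossEntropyLoss()(logits, labels)
print(itm_loss.item())
```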
@@ -392,6 +392,13 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
         self.assertEqual(outputs.logits.shape, expected_shape)
         self.assertTrue(outputs.logits[0, 1].item() > outputs.logits[0, 0].item())

+        # verify loss
+        inputs["labels"] = torch.ones(1, dtype=torch.long, device=torch_device)
+        inputs = inputs.to(torch_device)
+        with torch.no_grad():
+            outputs = model(**inputs)
+        self.assertAlmostEqual(outputs.loss.item(), 0.5108, places=4)
+
     @slow
     def test_masked_language_modeling(self):
         model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm").to(torch_device)
@@ -412,3 +419,62 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
         # verify predicted word
         predicted_id = outputs.logits.argmax(dim=-1).squeeze(0).tolist()[4]
         self.assertTrue(processor.decode([predicted_id]) == " cats")
+
+        # verify loss
+        inputs["labels"] = inputs["input_ids"].clone()
+        inputs = inputs.to(torch_device)
+        with torch.no_grad():
+            outputs = model(**inputs)
+        self.assertAlmostEqual(outputs.loss.item(), 5.7373, places=4)
+
+
+@require_torch
+@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
+class BridgeTowerModelTrainingTest(unittest.TestCase):
+    all_training_supported_model_classes = (
+        (BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
+    )
+
+    def setUp(self):
+        self.model_tester = BridgeTowerModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+
+    def _prepare_inputs_for_training(self, model_class):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if model_class == BridgeTowerForMaskedLM:
+            inputs_dict["labels"] = inputs_dict["input_ids"]
+        elif model_class == BridgeTowerForImageAndTextRetrieval:
+            inputs_dict["labels"] = ids_tensor([1], 2)
+        return config, inputs_dict
+
+    def _get_non_used_layer_names(self, model_class):
+        non_used_layer_names = ["text_model.pooler"]
+        if model_class == BridgeTowerForMaskedLM:
+            non_used_layer_names = non_used_layer_names + [
+                "cross_modal_image_layers.5",
+                "cross_modal_image_pooler",
+                "cross_modal_text_pooler",
+            ]
+        return non_used_layer_names
+
+    def _is_layer_used(self, model_class, layer_name):
+        non_used_layer_names = self._get_non_used_layer_names(model_class)
+        for non_used_layer_name in non_used_layer_names:
+            if non_used_layer_name in layer_name:
+                return False
+        return True
+
+    def test_training(self):
+        for model_class in self.all_training_supported_model_classes:
+            config, inputs_dict = self._prepare_inputs_for_training(model_class)
+            model = model_class(config)
+            model.to(torch_device)
+            model.train()
+
+            loss = model(**inputs_dict).loss
+            loss.backward()
+
+            # verify the gradients of used layers' weight are not None
+            for name, param in model.named_parameters():
+                if self._is_layer_used(model_class, name):
+                    self.assertIsNotNone(param.grad, f"Gradients should not be None - got {param.grad} for {name}")
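
The `_get_non_used_layer_names` filter in the new training test exists because parameters that never take part in the forward pass keep `grad=None` after `backward()`. A small self-contained illustration of that PyTorch behaviour (the toy module is made up and unrelated to BridgeTower):

```python
import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)    # never called in forward

    def forward(self, x):
        return self.used(x).sum()

model = Toy()
model(torch.randn(2, 4)).backward()

print(model.used.weight.grad is None)    # False: gradient was populated
print(model.unused.weight.grad is None)  # True: layer not used, so no gradient to check
```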