Unverified commit 1d3eaa6f authored by Billy Cao, committed by GitHub

Add training support for SigLIP (#31495)

* Add siglip loss function

* Update docs

* Enable training tests
[experimental] Enable GC training tests, as they have worked on my own data

* Remove test_training* overrides to enable training tests
[run_slow] siglip

* Skip training tests for Siglip text model and ImageClassificationModel
[run_slow] siglip

* Skip GC training tests for SiglipForImageClassification

* Explicitly skip training tests for SiglipVisionModel
Add skip reason for training tests for SiglipTextModel

* Remove `# Copied from` comments to fix CI
parent 15560252
...
@@ -27,7 +27,7 @@ The abstract from the paper is the following:

## Usage tips

- Usage of SigLIP is similar to [CLIP](clip). The main difference is the training loss, which does not require a global view of all the pairwise similarities of images and texts within a batch. One needs to apply the sigmoid activation function to the logits, rather than the softmax.
- - Training is not yet supported. If you want to fine-tune SigLIP or train from scratch, refer to the loss function from [OpenCLIP](https://github.com/mlfoundations/open_clip/blob/73ad04ae7fb93ede1c02dc9040a828634cb1edf1/src/open_clip/loss.py#L307), which leverages various `torch.distributed` utilities.
+ - Training is supported, but it does not use `torch.distributed` utilities to gather the pairwise similarities across devices, which may limit the effective batch size. However, DDP and FSDP work in a single-node multi-GPU setup.
- When using the standalone [`SiglipTokenizer`] or [`SiglipProcessor`], make sure to pass `padding="max_length"` as that's how the model was trained.
- To get the same results as the pipeline, a prompt template of "This is a photo of {label}." should be used.
...
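The new usage tip states that training is supported. For reference, a minimal single-device fine-tuning step could look like the sketch below; the checkpoint name, dummy batch, and optimizer settings are illustrative and not part of this commit.

```python
# Minimal sketch of one SigLIP fine-tuning step (illustrative checkpoint and dummy batch).
import torch
from PIL import Image
from transformers import SiglipModel, SiglipProcessor

model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")

# Dummy batch of two image-text pairs; padding="max_length" matches how the model was trained.
images = [Image.new("RGB", (224, 224)), Image.new("RGB", (224, 224))]
texts = ["This is a photo of a cat.", "This is a photo of a dog."]
inputs = processor(text=texts, images=images, padding="max_length", return_tensors="pt")

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model.train()

outputs = model(**inputs, return_loss=True)  # return_loss=True triggers the sigmoid loss added below
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
```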
...
@@ -1234,7 +1234,12 @@ class SiglipModel(SiglipPreTrainedModel):

        loss = None
        if return_loss:
-             raise NotImplementedError("SigLIP loss to be implemented")
+             # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
+             eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
+             m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
+             loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
+             nll = -torch.sum(loglik, dim=-1)
+             loss = nll.mean()

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
...
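The loss added above treats every text-image pair in the batch as an independent binary classification: matching pairs (the diagonal of `logits_per_text`) get label +1, all other pairs get label -1, and the per-pair log-sigmoid likelihoods are summed. A standalone sketch of the same computation, using a made-up 2x2 similarity matrix:

```python
# Standalone sketch of the sigmoid loss above; the similarity values are made up.
import torch
import torch.nn.functional as F

logits_per_text = torch.tensor([[4.0, -2.0], [-3.0, 5.0]])  # (num_texts, num_images)

eye = torch.eye(logits_per_text.size(0))
m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye  # +1 on matching pairs, -1 elsewhere
loglik = F.logsigmoid(m1_diag1 * logits_per_text)        # log-likelihood of the correct sign per pair
loss = -loglik.sum(dim=-1).mean()                        # negative log-likelihood, averaged over texts
print(loss)  # ~0.10 here, since the diagonal (matching) logits dominate
```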
...
@@ -335,27 +335,19 @@ class SiglipTextModelTest(ModelTesterMixin, unittest.TestCase):

        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

-     @unittest.skip
-     # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training
+     @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training(self):
        pass

-     @unittest.skip
-     # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing
+     @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training_gradient_checkpointing(self):
        pass

-     @unittest.skip(
-         reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-     )
-     # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant
+     @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

-     @unittest.skip(
-         reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-     )
-     # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant_false
+     @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass
...
@@ -481,22 +473,6 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

    def test_model_get_set_embeddings(self):
        pass

-     @unittest.skip(reason="SiglipModel does not support training")
-     def test_training(self):
-         pass
-
-     @unittest.skip(reason="SiglipModel does not support training")
-     def test_training_gradient_checkpointing(self):
-         pass
-
-     @unittest.skip(reason="SiglipModel does not support training")
-     def test_training_gradient_checkpointing_use_reentrant(self):
-         pass
-
-     @unittest.skip(reason="SiglipModel does not support training")
-     def test_training_gradient_checkpointing_use_reentrant_false(self):
-         pass
-
    @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass
...
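The re-enabled gradient-checkpointing tests above exercise `gradient_checkpointing_enable()` on the SigLIP models. A minimal sketch of turning it on before the training step shown earlier (the checkpoint name is illustrative):

```python
# Minimal sketch: gradient checkpointing trades extra forward compute for lower activation memory.
from transformers import SiglipModel

model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
model.gradient_checkpointing_enable()
model.train()
# The training step itself is unchanged from the fine-tuning sketch above.
```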