Commit 872e6be0 authored by Sachin Abeywardana, committed by GitHub

Update clip loss calculation (#13217)



* Update clip loss calculation

Hello, I'm the author of the blog you took the snippet from. I think this way of calculating the loss is possibly slightly more accurate.

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 0a22335e
@@ -61,14 +61,13 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 # contrastive loss function, adapted from
 # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
-def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
-    neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
-    return -neg_ce.mean()
+def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
+    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
 def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
-    caption_loss = contrastive_loss(similarity, dim=0)
-    image_loss = contrastive_loss(similarity, dim=1)
+    caption_loss = contrastive_loss(similarity)
+    image_loss = contrastive_loss(similarity.T)
     return (caption_loss + image_loss) / 2.0
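To make the relationship between the two versions concrete, here is a minimal torch-free sketch that mirrors both forms on a small square matrix. The helper names (`log_softmax_rows`, `transpose`, `old_contrastive_loss`, `new_contrastive_loss`) and the sample `similarity` values are hypothetical and exist only for illustration. They are plain-Python stand-ins, not the library code touched by this commit.

```python
import math
from typing import List

Matrix = List[List[float]]


def log_softmax_rows(logits: Matrix) -> Matrix:
    """Row-wise log-softmax, subtracting the row max for stability."""
    out = []
    for row in logits:
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        out.append([x - log_z for x in row])
    return out


def transpose(mat: Matrix) -> Matrix:
    return [list(col) for col in zip(*mat)]


def old_contrastive_loss(logits: Matrix, dim: int) -> float:
    """Pre-commit form: negative mean of diag(log_softmax(logits, dim))."""
    if dim == 1:
        log_probs = log_softmax_rows(logits)
    else:
        # Softmax over dim 0 is the row-wise softmax of the transpose.
        log_probs = transpose(log_softmax_rows(transpose(logits)))
    n = len(logits)
    return -sum(log_probs[i][i] for i in range(n)) / n


def new_contrastive_loss(logits: Matrix) -> float:
    """Post-commit form: mean cross-entropy where row i has target class i."""
    log_probs = log_softmax_rows(logits)
    targets = range(len(logits))  # stand-in for torch.arange(len(logits))
    return sum(-log_probs[i][t] for i, t in enumerate(targets)) / len(logits)


# Hypothetical 3x3 similarity matrix (illustrative values only).
similarity = [[2.0, 0.5, -1.0], [0.1, 1.5, 0.3], [-0.7, 0.2, 3.0]]

old_clip = (old_contrastive_loss(similarity, 0) + old_contrastive_loss(similarity, 1)) / 2.0
new_clip = (new_contrastive_loss(similarity) + new_contrastive_loss(transpose(similarity))) / 2.0
print(math.isclose(old_clip, new_clip))
```

In this sketch the new `contrastive_loss(similarity)` matches the old `dim=1` call and `contrastive_loss(similarity.T)` matches the old `dim=0` call, so the `caption_loss` and `image_loss` names trade places relative to the old code while their average is unchanged. The sketch says nothing about numerical differences in PyTorch's own `cross_entropy` kernel.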