"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "cdcc01be0ead8e3473ff88b95b8c53755a60750f"
Unverified commit 720dbfc9, authored by Benjamin Lefaudeux and committed by GitHub

Compute embedding distances with torch.cdist (#1459)

small but mighty
parent 513fc681
@@ -290,15 +290,10 @@ class VectorQuantizer(nn.Module):
         # reshape z -> (batch, height, width, channel) and flatten
         z = z.permute(0, 2, 3, 1).contiguous()
         z_flattened = z.view(-1, self.vq_embed_dim)
-        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
-        d = (
-            torch.sum(z_flattened**2, dim=1, keepdim=True)
-            + torch.sum(self.embedding.weight**2, dim=1)
-            - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
-        )
-        min_encoding_indices = torch.argmin(d, dim=1)
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
         z_q = self.embedding(min_encoding_indices).view(z.shape)
         perplexity = None
         min_encodings = None
...
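
The change is safe because argmin is unaffected by taking a square root: torch.cdist returns Euclidean distances, while the removed code computed squared distances via the expansion (z - e)^2 = z^2 + e^2 - 2 e * z, and both select the same nearest codebook entry. Below is a minimal standalone sketch of that equivalence, not part of the diffusers codebase; the tensor shapes are arbitrary placeholders rather than values from any model config.

import torch

torch.manual_seed(0)
z_flattened = torch.randn(64, 16)   # flattened latents: (batch*h*w, vq_embed_dim)
codebook = torch.randn(512, 16)     # embedding weights: (n_embed, vq_embed_dim)

# Old formulation: squared distances via z^2 + e^2 - 2 e * z
d = (
    torch.sum(z_flattened**2, dim=1, keepdim=True)
    + torch.sum(codebook**2, dim=1)
    - 2 * torch.einsum("bd,dn->bn", z_flattened, codebook.t())
)
old_indices = torch.argmin(d, dim=1)

# New formulation: Euclidean distances via torch.cdist
new_indices = torch.argmin(torch.cdist(z_flattened, codebook), dim=1)

# Same nearest-neighbour indices (exact ties could in principle differ,
# but are effectively impossible with random float inputs).
assert torch.equal(old_indices, new_indices)

Beyond being shorter, the cdist call avoids materializing the separate sum and einsum intermediates that the expanded form required.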