poolers.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention #TODO


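# Attention pooling head in the style of CLIP's AttentionPool2d: a mean token is
# prepended to the sequence, a learned positional embedding is added, and that
# single token attends over the full sequence before being projected to
# output_dim. `operations` supplies the Linear implementation (e.g. a comfy.ops
# class) so weight handling can be customized by the caller.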
class AttentionPool(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.empty(spacial_dim + 1, embed_dim, dtype=dtype, device=device))
        self.k_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.q_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.v_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.c_proj = operations.Linear(embed_dim, output_dim or embed_dim, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x):
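        # Truncate to at most spacial_dim tokens so the positional embedding
        # (which reserves one extra slot for the pooled token) lines up.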
        x = x[:,:self.positional_embedding.shape[0] - 1]
        x = x.permute(1, 0, 2)  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device)  # (L+1)NC

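        # The query comes only from the prepended mean token; keys/values span the
        # whole sequence, so attention pools the sequence into a single token.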
        q = self.q_proj(x[:1])
        k = self.k_proj(x)
        v = self.v_proj(x)

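        # Split heads: reshape q/k/v to (batch, heads, seq_len, head_dim), the
        # layout optimized_attention expects when skip_reshape=True.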
        batch_size = q.shape[1]
        head_dim = self.embed_dim // self.num_heads
        q = q.view(1, batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        k = k.view(k.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        v = v.view(v.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)

        attn_output = optimized_attention(q, k, v, self.num_heads, skip_reshape=True).transpose(0, 1)

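        # Project the pooled token and drop the length-1 sequence dim -> (batch, output_dim).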
        attn_output = self.c_proj(attn_output)
        return attn_output.squeeze(0)
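

# A minimal smoke test (illustrative sketch only, not part of the original module).
# It assumes plain torch.nn as the `operations` namespace; in ComfyUI this would
# normally be one of the comfy.ops classes, and all weights -- including
# positional_embedding, which torch.empty leaves uninitialized above -- would be
# loaded from a checkpoint rather than randomly initialized as done here.
if __name__ == "__main__":
    pool = AttentionPool(spacial_dim=77, embed_dim=1024, num_heads=8,
                         output_dim=1024, dtype=torch.float32, device="cpu",
                         operations=torch.nn)
    with torch.no_grad():
        pool.positional_embedding.normal_(std=0.02)  # placeholder init for the demo
        tokens = torch.randn(2, 77, 1024)  # (batch, seq_len, embed_dim)
        pooled = pool(tokens)
    print(pooled.shape)  # expected: torch.Size([2, 1024])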