Commit 8328a2d8 authored by comfyanonymous

Let Hunyuan DiT work with all prompt lengths.

parent afe732be
@@ -16,6 +16,7 @@ class AttentionPool(nn.Module):
         self.embed_dim = embed_dim
 
     def forward(self, x):
+        x = x[:,:self.positional_embedding.shape[0] - 1]
         x = x.permute(1, 0, 2)  # NLC -> LNC
         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
         x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device)  # (L+1)NC
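Why the added slice fixes long prompts: the forward pass prepends a mean token and then adds a positional embedding table whose first dimension is fixed at L+1 positions. A prompt embedding with more than L tokens makes the broadcast add fail with a shape mismatch, so truncating to `positional_embedding.shape[0] - 1` tokens keeps every prompt length valid. Below is a minimal runnable sketch of just this pattern, not the full ComfyUI module; the class name `AttentionPoolSketch`, the constructor signature, and the dimensions are illustrative assumptions, and the real module goes on to run multi-head attention after this pooling step.

```python
import torch
import torch.nn as nn

class AttentionPoolSketch(nn.Module):
    # Sketch of the truncation pattern from the diff; constructor args are assumed.
    def __init__(self, spacial_dim: int, embed_dim: int):
        super().__init__()
        # Table covers spacial_dim + 1 positions: one mean token plus spacial_dim tokens.
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)

    def forward(self, x):  # x: (N, L, C)
        # The fix from the commit: drop tokens the positional table cannot cover.
        x = x[:, :self.positional_embedding.shape[0] - 1]
        x = x.permute(1, 0, 2)                                  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        # Without the slice, a prompt longer than spacial_dim tokens would make
        # this addition raise a shape-mismatch error.
        x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device)
        return x

pool = AttentionPoolSketch(spacial_dim=77, embed_dim=1024)
short = torch.randn(1, 77, 1024)   # fits the table exactly
long = torch.randn(1, 300, 1024)   # longer prompt, now truncated instead of crashing
print(pool(short).shape, pool(long).shape)  # both: torch.Size([78, 1, 1024])
```

Truncation is a deliberate trade-off here: tokens past the table length are silently dropped rather than rejected, which matches the commit's goal of accepting any prompt length.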