Unverified Commit 8c6b47cf authored by Will Berman, committed by GitHub

`AttentionProcessor.group_norm` num_channels should be `query_dim` (#3046)

* `AttentionProcessor.group_norm` num_channels should be `query_dim`

The group_norm on the attention processor should really normalize the number
of channels in the query, _not_ the inner dim. This wasn't caught before
because the group_norm is only used by the added-KV attention processors, and
those processors are only used by the Karlo models, which are configured such
that the inner dim is the same as the query dim. A short sketch after the
commit message illustrates the channel mismatch.

* `add_{k,v}_proj` should project to `inner_dim`
parent 67ec9cf5
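A minimal sketch of the first point (not part of this commit; the dims 320/640/32 and the tensor shape are assumed example values): the group norm runs on the hidden states before the to_q projection, so it sees query_dim channels, and a GroupNorm built with num_channels=inner_dim fails as soon as inner_dim differs from query_dim.

    # Hypothetical example values, not the diffusers source.
    import torch
    import torch.nn as nn

    query_dim, inner_dim, norm_num_groups = 320, 640, 32

    # Hidden states entering the attention block: (batch, channels=query_dim, seq_len).
    hidden_states = torch.randn(2, query_dim, 64)

    old_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
    new_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)

    print(new_norm(hidden_states).shape)  # works: torch.Size([2, 320, 64])
    try:
        old_norm(hidden_states)  # fails whenever inner_dim != query_dim
    except RuntimeError as err:
        print(err)  # affine weight has inner_dim entries but the input has query_dim channels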
@@ -81,7 +81,7 @@ class Attention(nn.Module):
         self.added_kv_proj_dim = added_kv_proj_dim

         if norm_num_groups is not None:
-            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
         else:
             self.group_norm = None
@@ -93,8 +93,8 @@ class Attention(nn.Module):
         self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
-            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)

         self.to_out = nn.ModuleList([])
         self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
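A minimal sketch of the second hunk (again with assumed example shapes, not the diffusers source): the outputs of add_k_proj and add_v_proj are split into `heads` chunks of `dim_head` each, so their feature size has to be heads * dim_head = inner_dim, independent of cross_attention_dim.

    # Hypothetical example values, not the diffusers source.
    import torch
    import torch.nn as nn

    added_kv_proj_dim, heads, dim_head = 768, 8, 80
    inner_dim = heads * dim_head  # 640

    add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)

    encoder_hidden_states = torch.randn(2, 77, added_kv_proj_dim)
    k = add_k_proj(encoder_hidden_states)                  # (2, 77, inner_dim)
    k = k.reshape(2, 77, heads, dim_head).transpose(1, 2)  # (2, heads, 77, dim_head)
    print(k.shape)  # projecting to cross_attention_dim instead would break this split whenever it != inner_dim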