Unverified Commit c0c11683 authored by Eliseu Silva's avatar Eliseu Silva Committed by GitHub
Browse files

Make passing the IP Adapter mask to the attention mechanism optional (#10346)

Make passing the IP Adapter mask to the attention mechanism optional if there is no need to apply it to a given IP Adapter.
parent 6dfaec34
...@@ -4839,6 +4839,8 @@ class IPAdapterAttnProcessor(nn.Module): ...@@ -4839,6 +4839,8 @@ class IPAdapterAttnProcessor(nn.Module):
) )
else: else:
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
if mask is None:
continue
if not isinstance(mask, torch.Tensor) or mask.ndim != 4: if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError( raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape " "Each element of the ip_adapter_masks array should be a tensor with shape "
...@@ -5056,6 +5058,8 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): ...@@ -5056,6 +5058,8 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
) )
else: else:
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
if mask is None:
continue
if not isinstance(mask, torch.Tensor) or mask.ndim != 4: if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError( raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape " "Each element of the ip_adapter_masks array should be a tensor with shape "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment