"docs/vscode:/vscode.git/clone" did not exist on "f685981ed0e63af625db99b42863d9cd8a245176"
Unverified Commit c0c11683 authored by Eliseu Silva's avatar Eliseu Silva Committed by GitHub
Browse files

Make passing the IP Adapter mask to the attention mechanism optional (#10346)

Make passing the IP Adapter mask to the attention mechanism optional if there is no need to apply it to a given IP Adapter.
parent 6dfaec34
...@@ -4839,6 +4839,8 @@ class IPAdapterAttnProcessor(nn.Module): ...@@ -4839,6 +4839,8 @@ class IPAdapterAttnProcessor(nn.Module):
) )
else: else:
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
if mask is None:
continue
if not isinstance(mask, torch.Tensor) or mask.ndim != 4: if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError( raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape " "Each element of the ip_adapter_masks array should be a tensor with shape "
...@@ -5056,6 +5058,8 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): ...@@ -5056,6 +5058,8 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
) )
else: else:
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
if mask is None:
continue
if not isinstance(mask, torch.Tensor) or mask.ndim != 4: if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError( raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape " "Each element of the ip_adapter_masks array should be a tensor with shape "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment