Unverified Commit 0071478d authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

allow attention processors to have different signatures (#6915)



add
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent 7c8cab31
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
from importlib import import_module from importlib import import_module
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
...@@ -509,6 +510,15 @@ class Attention(nn.Module): ...@@ -509,6 +510,15 @@ class Attention(nn.Module):
# The `Attention` class can call different attention processors / attention functions # The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class # here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty # For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
return self.processor( return self.processor(
self, self,
hidden_states, hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment