Unverified Commit 4f0d01d3 authored by psychedelicious's avatar psychedelicious Committed by GitHub
Browse files

type `get_attention_scores` as optional in `get_attention_scores` (#9075)

`None` is valid for `get_attention_scores`, should be typed as such
parent 3dc10a53
...@@ -539,7 +539,7 @@ class Attention(nn.Module): ...@@ -539,7 +539,7 @@ class Attention(nn.Module):
return tensor return tensor
def get_attention_scores( def get_attention_scores(
self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
r""" r"""
Compute the attention scores. Compute the attention scores.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment