Unverified Commit b785ddb6 authored by Junyu Chen, committed by GitHub

[DC-AE, SANA] fix SanaMultiscaleLinearAttention apply_quadratic_attention bf16 (#10595)



* autoencoder_dc tiling

* add tiling and slicing support in SANA pipelines (usage sketch below)

* create variables for the padding lengths because the line was getting too long

* add tiling and slicing support in pag SANA pipelines

* revert changes to tile size

* make style

* add vae tiling test

* fix SanaMultiscaleLinearAttention apply_quadratic_attention bf16

---------
Co-authored-by: Aryan <aryan@huggingface.co>
parent e8114bd0
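For reference, a minimal usage sketch of the tiling and slicing support mentioned in the commit message, assuming the `SanaPipeline` and `AutoencoderDC` APIs in diffusers; the checkpoint id is illustrative, not taken from this commit:

```python
import torch
from diffusers import SanaPipeline

# Load a SANA pipeline in bfloat16, the dtype this fix targets.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

# Tiling decodes the latent in tiles; slicing decodes a batch one sample
# at a time. Both trade some speed for lower peak VRAM in the DC-AE.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

image = pipe(prompt="a photo of an astronaut riding a horse").images[0]
```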
@@ -899,7 +899,7 @@ class SanaMultiscaleLinearAttention(nn.Module):
         scores = torch.matmul(key.transpose(-1, -2), query)
         scores = scores.to(dtype=torch.float32)
         scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
-        hidden_states = torch.matmul(value, scores)
+        hidden_states = torch.matmul(value, scores.to(value.dtype))
         return hidden_states
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
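The one-line change above is needed because `torch.matmul` requires both operands to share a dtype: the scores are upcast to float32 for a numerically stable normalization, but `value` stays in bfloat16, so the product fails without casting the scores back. A minimal repro sketch (the shapes are made up for illustration):

```python
import torch

# `value` is in the model's compute dtype; `scores` was upcast to float32
# for the stable sum-normalization in apply_quadratic_attention.
value = torch.randn(1, 8, 64, 256, dtype=torch.bfloat16)
scores = torch.randn(1, 8, 256, 256, dtype=torch.float32)

try:
    torch.matmul(value, scores)  # mixed-dtype matmul raises a RuntimeError
except RuntimeError as err:
    print("fails:", err)

# The fix: cast the normalized scores back to value.dtype before multiplying.
out = torch.matmul(value, scores.to(value.dtype))
print(out.dtype)  # torch.bfloat16
```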