Unverified Commit 7fe47596 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Allow diffusers to load with Flax, w/o PyTorch (#6272)

parent 59d1caa2
...@@ -89,7 +89,7 @@ def is_compiled_module(module) -> bool: ...@@ -89,7 +89,7 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tensor: def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
This version of the method comes from here: This version of the method comes from here:
...@@ -121,8 +121,8 @@ def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tens ...@@ -121,8 +121,8 @@ def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tens
def apply_freeu( def apply_freeu(
resolution_idx: int, hidden_states: torch.Tensor, res_hidden_states: torch.Tensor, **freeu_kwargs resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple["torch.Tensor", "torch.Tensor"]:
"""Applies the FreeU mechanism as introduced in https: """Applies the FreeU mechanism as introduced in https:
//arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU. //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment