    Fix import with Flax but without PyTorch (#688) · 688031c5
    Pedro Cuenca authored
    * Don't use `load_state_dict` if torch is not installed.
    
    * Define `SchedulerOutput` to use torch or flax arrays.
    
    * Don't import LMSDiscreteScheduler without torch.
    
    * Create distinct FlaxSchedulerOutput.
    
    * Additional changes required for `FlaxSchedulerMixin`.
    
    * Do not import torch pipelines in Flax.
    
    * Revert "Define `SchedulerOutput` to use torch or flax arrays."
    
    This reverts commit f653140134b74d9ffec46d970eb46925fe3a409d.
    
    * Prefix Flax scheduler outputs for consistency.
    
    * make style
    
    * FlaxSchedulerOutput is now a dataclass.
    
    * Don't use f-string without placeholders.
    
    * Add blank line.
    
    * Style (docstrings)
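The torch-related bullets above (skipping `load_state_dict` without torch, not importing torch-only schedulers and pipelines in Flax, and a distinct `FlaxSchedulerOutput`) follow a common optional-dependency pattern, which can be sketched in plain Python. This is a minimal illustration, not the code from this commit: `is_torch_available`, `load_pytorch_state_dict`, and the `prev_sample` field are assumptions, and the standard-library `dataclasses` module stands in for whatever the actual Flax output class uses.

```python
import importlib.util
from dataclasses import dataclass
from typing import Any


def is_torch_available() -> bool:
    # Hypothetical helper: reports whether `torch` can be imported here,
    # without actually importing it.
    return importlib.util.find_spec("torch") is not None


def load_pytorch_state_dict(path: str) -> Any:
    # Hypothetical loader: only touch torch when it is installed, so that a
    # Flax-only environment can still import the module defining this.
    if not is_torch_available():
        raise ImportError(
            "Loading a PyTorch checkpoint requires `torch`; "
            "install it or load Flax weights instead."
        )
    import torch  # deferred import: never executed in Flax-only setups

    return torch.load(path, map_location="cpu")


@dataclass
class FlaxSchedulerOutput:
    # Kept separate from the torch `SchedulerOutput` so Flax code never needs
    # torch types. `prev_sample` is an assumed field name; in real Flax code it
    # would hold a `jnp.ndarray`, typed as `Any` here to stay JAX-free.
    prev_sample: Any
```

Because the `torch` import lives inside the function body, importing the module in an environment with only JAX/Flax installed succeeds; the error surfaces only if a PyTorch checkpoint is actually requested.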
modeling_flax_utils.py 26.6 KB