    UNet Flax with FlaxModelMixin (#502) · d8b0e4f4
    Pedro Cuenca authored
    
    
    * First UNet Flax modeling blocks.
    
    Mimic the structure of the PyTorch files.
    The model classes themselves need work, depending on what we do about
    configuration and initialization.
    
    * Remove FlaxUNet2DConfig class.
    
    * Add non-config args to `ignore_for_config`.
    
    * Implement `FlaxModelMixin`
    
    * Use new mixins for Flax UNet.
    
    For some reason the configuration is not correctly applied; the
    signature of the `__init__` method does not contain all the parameters
    by the time it's inspected in `extract_init_dict`.
    
    * Import `FlaxUNet2DConditionModel` if flax is available.
    
    * Rm unused method `framework`.
    
    * Update src/diffusers/modeling_flax_utils.py
    Co-authored-by: Suraj Patil <surajp815@gmail.com>
    
    * Indicate types in `flax.struct.dataclass`, as pointed out by @mishig25.
    Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
    
    * Fix typo in transformer block.
    
    * make style
    
    * some more changes
    
    * make style
    
    * Add comment
    
    * Update src/diffusers/modeling_flax_utils.py
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    
    * Rm unneeded comment
    
    * Update docstrings
    
    * correct ignore kwargs
    
    * make style
    
    * Update docstring examples
    
    * Make style
    
    * Style: remove empty line.
    
    * Apply style (after upgrading black from pinned version)
    
    * Remove some commented code and unused imports.
    
    * Add `init_weights` (not used until #513 lands).
    
    * Trickle down deterministic to blocks.
    
    * Rename q, k, v to match the latest PyTorch version.
    
    Note that weights were exported with the old names, so we need to be
    careful.
    
    * Flax UNet docstrings, default props as in PyTorch.
    
    * Fix minor typos in PyTorch docstrings.
    
    * Use FlaxUNet2DConditionOutput as output from UNet.
    
    * make style
    Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
    Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
    Co-authored-by: Suraj Patil <surajp815@gmail.com>
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>