1. 20 Sep, 2022 1 commit
    • Younes Belkada's avatar
      Add `from_pt` argument in `.from_pretrained` (#527) · 0902449e
      Younes Belkada authored
      * first commit:
      
      - add `from_pt` argument in `from_pretrained` function
      - add `modeling_flax_pytorch_utils.py` file
      
      * small nit
      
      - fix a small nit - to not enter in the second if condition
      
      * major changes
      
      - modify FlaxUnet modules
      - first conversion script
      - more keys to be matched
      
      * keys match
      
      - now all keys match
      - change module names for correct matching
      - upsample module name changed
      
      * working v1
      
      - test pass with atol and rtol= `4e-02`
      
      * replace unsued arg
      
      * make quality
      
      * add small docstring
      
      * add more comments
      
      - add TODO for embedding layers
      
      * small change
      
      - use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
      
      * add more conditions on conversion
      
      - add better test to check for keys conversion
      
      * make shapes consistent
      
      - output `img_w x img_h x n_channels` from the VAE
      
      * Revert "make shapes consistent"
      
      This reverts commit 4cad1aeb4aeb224402dad13c018a5d42e96267f6.
      
      * fix unet shape
      
      - channels first!
      0902449e