    Add `from_pt` argument in `.from_pretrained` (#527) · 0902449e
    Younes Belkada authored
    * first commit:
    
    - add `from_pt` argument in `from_pretrained` function
    - add `modeling_flax_pytorch_utils.py` file
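The commit message does not show the conversion code. Suppose the converter first produces a flat mapping of dot-separated keys, which is the layout of a PyTorch `state_dict`. Those keys would then need to be nested into the dict-of-dicts structure that Flax uses for parameters. `unflatten_state_dict` is a hypothetical helper, not a function from this commit. Flax's own `flax.traverse_util.unflatten_dict` does the same job for tuple keys.

```python
from typing import Any, Dict


def unflatten_state_dict(flat: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Nest a flat {"a.b.c": value} mapping into {"a": {"b": {"c": value}}}."""
    nested: Dict[str, Any] = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for part in parents:
            # Create each intermediate level the first time it is seen.
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


flat = {"conv_in.kernel": 1, "conv_in.bias": 2, "time_embedding.linear_1.kernel": 3}
print(unflatten_state_dict(flat))
# {'conv_in': {'kernel': 1, 'bias': 2}, 'time_embedding': {'linear_1': {'kernel': 3}}}
```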
    
    * small nit
    
    - fix a small nit: avoid entering the second `if` condition
    
    * major changes
    
    - modify FlaxUnet modules
    - first conversion script
    - more keys to be matched
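Here is a sketch of the per-tensor renaming and reshaping that a PyTorch-to-Flax conversion typically needs. It assumes NumPy arrays and the standard layouts of both frameworks:

- PyTorch `Conv2d` weights are `(out, in, kh, kw)`; Flax `nn.Conv` kernels are `(kh, kw, in, out)`.
- PyTorch `Linear` weights are `(out, in)`; Flax `nn.Dense` kernels are `(in, out)`.
- 1-D normalization weights are called `scale` in Flax.

`rename_and_reshape` is a hypothetical name. Dispatching on tensor rank is a simplification, not necessarily what this commit does.

```python
import numpy as np


def rename_and_reshape(pt_key: str, tensor: np.ndarray):
    """Map one PyTorch (key, tensor) pair to a Flax-style (key, array) pair.

    Hypothetical helper: it dispatches on the tensor's rank, which is a simplification.
    """
    prefix, sep, leaf = pt_key.rpartition(".")
    if leaf != "weight":
        # Biases keep their name and layout in both frameworks.
        return pt_key, tensor
    if tensor.ndim == 4:
        # Conv2d: PyTorch (out, in, kh, kw) -> Flax nn.Conv (kh, kw, in, out).
        return prefix + sep + "kernel", tensor.transpose(2, 3, 1, 0)
    if tensor.ndim == 2:
        # Linear: PyTorch (out, in) -> Flax nn.Dense (in, out).
        # Caveat: nn.Embedding weights are also 2-D but map to "embedding"
        # without a transpose; this sketch does not handle them.
        return prefix + sep + "kernel", tensor.T
    if tensor.ndim == 1:
        # GroupNorm / LayerNorm: PyTorch "weight" is called "scale" in Flax.
        return prefix + sep + "scale", tensor
    raise ValueError(f"unexpected rank {tensor.ndim} for {pt_key}")


key, kernel = rename_and_reshape("conv_in.weight", np.zeros((320, 4, 3, 3)))
print(key, kernel.shape)  # conv_in.kernel (3, 3, 4, 320)
```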
    
    * keys match
    
    - now all keys match
    - change module names for correct matching
    - upsample module name changed
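The commit does not say how the names were made to match, but one likely cause of the mismatch is indexing. PyTorch `nn.ModuleList` children appear in the state dict as `name.0`, `name.1`, and so on. Flax names submodules held in a list attribute `name_0`, `name_1`. Below is a minimal sketch of a regex that rewrites the PyTorch form into the Flax form. The function name is hypothetical.

```python
import re

# "name.<index>" segments as produced by nn.ModuleList in a PyTorch state dict.
_INDEXED_MODULE = re.compile(r"(\w+)\.(\d+)")


def rename_indexed_modules(pt_key: str) -> str:
    """Rewrite "down_blocks.0"-style segments as "down_blocks_0" (hypothetical helper)."""
    return _INDEXED_MODULE.sub(r"\1_\2", pt_key)


print(rename_indexed_modules("down_blocks.0.resnets.1.conv1.weight"))
# down_blocks_0.resnets_1.conv1.weight
```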
    
    * working v1
    
    - tests pass with `atol` and `rtol` set to `4e-02`
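This sketch shows how a tolerance check between the Flax and PyTorch outputs could be written with NumPy. It assumes both outputs have already been converted to NumPy arrays. `np.allclose` passes when `|a - b| <= atol + rtol * |b|` holds element-wise. `outputs_match` is a hypothetical helper.

```python
import numpy as np


def outputs_match(flax_out, pt_out, tol: float = 4e-2) -> bool:
    """True if both outputs have the same shape and agree within atol = rtol = tol."""
    flax_out, pt_out = np.asarray(flax_out), np.asarray(pt_out)
    if flax_out.shape != pt_out.shape:
        return False
    return bool(np.allclose(flax_out, pt_out, atol=tol, rtol=tol))


reference = np.array([1.0, -0.5, 0.0])
print(outputs_match(reference + 0.03, reference))  # True
print(outputs_match(reference + 0.1, reference))   # False
```

The explicit shape check matters because `np.allclose` broadcasts. Without the check, a layout bug could slip through whenever the mismatched shapes happen to be broadcast-compatible.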
    
    * replace unused arg
    
    * make quality
    
    * add small docstring
    
    * add more comments
    
    - add TODO for embedding layers
    
    * small change
    
    - use `jnp.expand_dims` to turn `timesteps` into a 1-dimensional array when it is 0-dimensional
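A 0-dimensional `timesteps` array is a problem because a sinusoidal timestep embedding usually indexes `timesteps[:, None]`, which needs a batch axis. The sketch below assumes that kind of embedding, and both function names are hypothetical. The import falls back to NumPy, whose `expand_dims` has the same semantics. The fallback exists only so the example runs without JAX installed.

```python
try:
    import jax.numpy as jnp
except ImportError:
    # Fallback so the sketch runs without JAX; numpy.expand_dims behaves the same.
    import numpy as jnp


def ensure_1d_timesteps(timesteps):
    """Give a scalar timestep a batch axis: shape () -> (1,); 1-D input is unchanged."""
    timesteps = jnp.asarray(timesteps)
    if timesteps.ndim == 0:
        timesteps = jnp.expand_dims(timesteps, 0)
    return timesteps


def sinusoidal_embedding(timesteps, dim: int):
    """Minimal sinusoidal embedding (hypothetical); expects timesteps of shape (batch,)."""
    half = dim // 2
    freqs = jnp.exp(-jnp.log(10000.0) * jnp.arange(half) / half)
    args = timesteps[:, None] * freqs[None, :]  # would fail on a 0-d array
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)


emb = sinusoidal_embedding(ensure_1d_timesteps(jnp.array(10.0)), dim=8)
print(emb.shape)  # (1, 8)
```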
    
    * add more conditions on conversion
    
    - add a better test to check key conversion
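A key-conversion test can compare the set of converted parameter paths with the paths of a freshly initialized Flax model, then report anything missing or unexpected. The sketch below is plain Python and assumes both sides are nested mappings (Flax's `FrozenDict` is a `Mapping`). The helper names are hypothetical. In practice, `expected` would come from the model's `init` output rather than a literal dict.

```python
from collections.abc import Mapping
from typing import List, Set, Tuple


def flatten_keys(tree: Mapping, prefix: str = "") -> Set[str]:
    """Collect dot-joined paths to every leaf of a nested mapping."""
    keys: Set[str] = set()
    for name, value in tree.items():
        path = f"{prefix}.{name}" if prefix else str(name)
        if isinstance(value, Mapping):
            keys |= flatten_keys(value, path)
        else:
            keys.add(path)
    return keys


def check_key_coverage(converted: Mapping, expected: Mapping) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) parameter paths, each sorted."""
    got, want = flatten_keys(converted), flatten_keys(expected)
    return sorted(want - got), sorted(got - want)


expected = {"conv_in": {"kernel": 0, "bias": 0}, "conv_out": {"kernel": 0}}
converted = {"conv_in": {"kernel": 0, "bias": 0}, "conv_out": {"weight": 0}}
print(check_key_coverage(converted, expected))
# (['conv_out.kernel'], ['conv_out.weight'])
```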
    
    * make shapes consistent
    
    - output `img_w x img_h x n_channels` from the VAE
    
    * Revert "make shapes consistent"
    
    This reverts commit 4cad1aeb4aeb224402dad13c018a5d42e96267f6.
    
    * fix unet shape
    
    - channels first!
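Flax's `nn.Conv` works on channels-last inputs `(N, H, W, C)` by default, while the PyTorch UNet takes channels-first `(N, C, H, W)`. This sketch assumes "channels first!" means the Flax UNet keeps the PyTorch-facing layout and transposes internally. The wrapper and helper names are hypothetical. NumPy stands in for `jax.numpy`, which provides the same `transpose`.

```python
import numpy as np


def nchw_to_nhwc(x):
    # (N, C, H, W) -> (N, H, W, C): move channels to the last axis.
    return np.transpose(x, (0, 2, 3, 1))


def nhwc_to_nchw(x):
    # (N, H, W, C) -> (N, C, H, W): move channels back to axis 1.
    return np.transpose(x, (0, 3, 1, 2))


def channels_first(nhwc_fn):
    """Wrap a channels-last function so callers pass and receive (N, C, H, W)."""
    def wrapped(x_nchw):
        return nhwc_to_nchw(nhwc_fn(nchw_to_nhwc(x_nchw)))
    return wrapped


# Stand-in for the Flax network body: scales each channel (the last axis in NHWC).
channel_scale = np.array([1.0, 10.0, 100.0, 1000.0])
model = channels_first(lambda h: h * channel_scale)

x = np.ones((2, 4, 8, 8))  # N=2, C=4, H=W=8
y = model(x)
print(y.shape)  # (2, 4, 8, 8)
```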
modeling_flax_pytorch_utils.py 4.49 KB