Unverified commit 0902449e authored by Younes Belkada, committed by GitHub

Add `from_pt` argument in `.from_pretrained` (#527)

* first commit:

- add `from_pt` argument in `from_pretrained` function
- add `modeling_flax_pytorch_utils.py` file

* small nit

- fix a small nit so we do not enter the second if condition

* major changes

- modify FlaxUnet modules
- first conversion script
- more keys to be matched

* keys match

- now all keys match
- change module names for correct matching
- upsample module name changed

* working v1

- tests pass with atol and rtol of `4e-02`

* replace unused arg

* make quality

* add small docstring

* add more comments

- add TODO for embedding layers

* small change

- use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array

* add more conditions on conversion

- add a better test to check key conversion

* make shapes consistent

- output `img_w x img_h x n_channels` from the VAE

* Revert "make shapes consistent"

This reverts commit 4cad1aeb4aeb224402dad13c018a5d42e96267f6.

* fix unet shape

- channels first!
parent ca749513
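
With the new flag, a Flax model class can load a PyTorch checkpoint directly. A minimal usage sketch (the local path is illustrative, and the two-value return assumes `FlaxModelMixin.from_pretrained` yields the module together with its parameters):

    from diffusers import FlaxUNet2DConditionModel

    # "./unet" stands for a directory containing config.json and the PyTorch
    # weights file (diffusion_pytorch_model.bin); weights are converted on load.
    model, params = FlaxUNet2DConditionModel.from_pretrained("./unet", from_pt=True)

modeling_flax_pytorch_utils.py (new file):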
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""
import re
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from .utils import logging
logger = logging.get_logger(__name__)
def rename_key(key):
    regex = r"\w+[.]\d+"
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key
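
# Illustrative example of the renaming (hypothetical key):
#   rename_key("down_blocks.0.resnets.1.conv1.weight")
#   -> "down_blocks_0.resnets_1.conv1.weight"
# i.e. every `<name>.<digits>` pair becomes `<name>_<digits>`, so PyTorch
# ModuleList indices line up with Flax attribute names.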
#####################
# PyTorch => Flax #
#####################
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""

    # conv norm or layer norm
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
    if (
        any("norm" in str_ for str_ in pt_tuple_key)
        and (pt_tuple_key[-1] == "bias")
        and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
        and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
    ):
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor
    elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor

    # embedding
    if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
        # return the embedding rename, not the stale "scale" rename from above
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        return renamed_pt_tuple_key, pt_tensor

    # conv layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
        # PyTorch conv kernels are (out, in, kH, kW); Flax expects (kH, kW, in, out)
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        return renamed_pt_tuple_key, pt_tensor

    # linear layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight":
        pt_tensor = pt_tensor.T
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm weight
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    if pt_tuple_key[-1] == "gamma":
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm bias
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
    if pt_tuple_key[-1] == "beta":
        return renamed_pt_tuple_key, pt_tensor

    return pt_tuple_key, pt_tensor
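
# Shape sanity check for the conv branch above (illustrative shapes):
#   a PyTorch Conv2d kernel is (out_ch, in_ch, kH, kW), while Flax nn.Conv
#   stores (kH, kW, in_ch, out_ch), hence transpose(2, 3, 1, 0).
#
#   >>> import numpy as np
#   >>> np.zeros((320, 4, 3, 3)).transpose(2, 3, 1, 0).shape
#   (3, 3, 4, 320)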
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
    """Convert a PyTorch state dict into the parameter dict of an initialized Flax model."""
    # Step 1: Convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    # Step 2: Since the model is stateless, get random Flax params
    random_flax_params = flax_model.init_weights(PRNGKey(init_key))

    random_flax_state_dict = flatten_dict(random_flax_params)
    flax_state_dict = {}

    # Need to change some parameters name to match Flax names
    for pt_key, pt_tensor in pt_state_dict.items():
        renamed_pt_key = rename_key(pt_key)
        pt_tuple_key = tuple(renamed_pt_key.split("."))

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)

        if flax_key in random_flax_state_dict:
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(flax_state_dict)
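
The converter can also be called directly. A hedged sketch, where `model` is any `FlaxModelMixin` instance exposing `init_weights` (e.g. a freshly constructed `FlaxUNet2DConditionModel`) and the filename matches diffusers' `WEIGHTS_NAME`:

    import torch

    pt_state_dict = torch.load("diffusion_pytorch_model.bin", map_location="cpu")
    flax_params = convert_pytorch_state_dict_to_flax(pt_state_dict, model)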
modeling_flax_utils.py:

@@ -27,7 +27,8 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError

-from .modeling_utils import WEIGHTS_NAME
+from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
+from .modeling_utils import WEIGHTS_NAME, load_state_dict
 from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
@@ -245,6 +246,8 @@ class FlaxModelMixin:
                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                 identifier allowed by git.
+            from_pt (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a PyTorch checkpoint save file.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -272,6 +275,7 @@ class FlaxModelMixin:
         config = kwargs.pop("config", None)
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         force_download = kwargs.pop("force_download", False)
+        from_pt = kwargs.pop("from_pt", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
@@ -306,10 +310,16 @@ class FlaxModelMixin:
                 # Load from a Flax checkpoint
                 model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
             # At this stage we don't have a weight file so we will raise an error.
-            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
-                raise EnvironmentError(
-                    f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
-                    "but there is a file for PyTorch weights."
-                )
+            elif from_pt:
+                if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    raise EnvironmentError(
+                        f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}"
+                    )
+                model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                raise EnvironmentError(
+                    f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
+                    " using `from_pt=True`."
+                )
             else:
                 raise EnvironmentError(
@@ -320,7 +330,7 @@ class FlaxModelMixin:
             try:
                 model_file = hf_hub_download(
                     pretrained_model_name_or_path,
-                    filename=FLAX_WEIGHTS_NAME,
+                    filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     proxies=proxies,
@@ -370,6 +380,13 @@ class FlaxModelMixin:
                 f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
             )

-        try:
-            with open(model_file, "rb") as state_f:
-                state = from_bytes(cls, state_f.read())
+        if from_pt:
+            # Step 1: Get the pytorch file
+            pytorch_model_file = load_state_dict(model_file)
+
+            # Step 2: Convert the weights
+            state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
+        else:
+            try:
+                with open(model_file, "rb") as state_f:
+                    state = from_bytes(cls, state_f.read())
...
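
The commit message reports PyTorch/Flax parity within `atol` and `rtol` of `4e-02`. A hedged sketch of such a check, where `pt_unet`, `flax_unet`, `params`, and the input shapes are hypothetical stand-ins:

    import numpy as np
    import jax.numpy as jnp
    import torch

    sample = np.random.randn(1, 4, 64, 64).astype("float32")
    t = np.array([10], dtype="int64")
    context = np.random.randn(1, 77, 768).astype("float32")

    with torch.no_grad():
        pt_out = pt_unet(
            torch.from_numpy(sample), torch.from_numpy(t), torch.from_numpy(context)
        ).sample.numpy()
    flax_out = flax_unet.apply(
        {"params": params}, jnp.asarray(sample), jnp.asarray(t), jnp.asarray(context)
    ).sample

    assert np.allclose(pt_out, np.asarray(flax_out), atol=4e-2, rtol=4e-2)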
attention_flax.py:

@@ -32,7 +32,7 @@ class FlaxAttentionBlock(nn.Module):
         self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

-        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
+        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")

     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -82,9 +82,9 @@ class FlaxBasicTransformerBlock(nn.Module):
     def setup(self):
         # self attention
-        self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         # cross attention
-        self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -93,12 +93,12 @@ class FlaxBasicTransformerBlock(nn.Module):
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
         residual = hidden_states
-        hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
+        hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual

         # cross attention
         residual = hidden_states
-        hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
+        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
         hidden_states = hidden_states + residual

         # feed forward
@@ -167,14 +167,28 @@ class FlaxGluFeedForward(nn.Module):
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32

+    def setup(self):
+        # The second linear layer needs to be called
+        # net_2 for now to match the index of the Sequential layer
+        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
+        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
+
+    def __call__(self, hidden_states, deterministic=True):
+        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_2(hidden_states)
+        return hidden_states
+
+
+class FlaxGEGLU(nn.Module):
+    dim: int
+    dropout: float = 0.0
+    dtype: jnp.dtype = jnp.float32
+
     def setup(self):
         inner_dim = self.dim * 4
-        self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
-        self.dense2 = nn.Dense(self.dim, dtype=self.dtype)
+        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)

     def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.dense1(hidden_states)
+        hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        hidden_states = hidden_linear * nn.gelu(hidden_gelu)
-        hidden_states = self.dense2(hidden_states)
-        return hidden_states
+        return hidden_linear * nn.gelu(hidden_gelu)
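
The new `FlaxGEGLU` is the gated-GELU half of the feed-forward: one projection produces both the value and the gate, split along the feature axis. A tiny numeric sketch of that split (shapes illustrative):

    import jax
    import jax.numpy as jnp

    h = jnp.arange(12.0).reshape(1, 2, 6)              # pretend output of self.proj
    hidden_linear, hidden_gelu = jnp.split(h, 2, axis=2)
    out = hidden_linear * jax.nn.gelu(hidden_gelu)     # shape (1, 2, 3)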
unet_2d_condition_flax.py:

@@ -76,7 +76,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
         # init input tensors
-        sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
+        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
         sample = jnp.zeros(sample_shape, dtype=jnp.float32)
         timesteps = jnp.ones((1,), dtype=jnp.int32)
         encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
@@ -214,10 +214,17 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             When returning a tuple, the first element is the sample tensor.
         """
         # 1. time
+        if not isinstance(timesteps, jnp.ndarray):
+            timesteps = jnp.array([timesteps], dtype=jnp.int32)
+        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
+            timesteps = timesteps.astype(dtype=jnp.float32)
+            timesteps = jnp.expand_dims(timesteps, 0)
+
         t_emb = self.time_proj(timesteps)
         t_emb = self.time_embedding(t_emb)

         # 2. pre-process
+        sample = jnp.transpose(sample, (0, 2, 3, 1))
         sample = self.conv_in(sample)

         # 3. down
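
The new timestep handling normalizes `timesteps` so the sinusoidal projection always sees a 1-d array; both a plain Python int and a 0-dimensional array end up with shape `(1,)`:

    import jax.numpy as jnp

    t = jnp.array(10)                                  # 0-dimensional
    t = jnp.expand_dims(t.astype(jnp.float32), 0)
    print(t.shape)                                     # (1,)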
@@ -251,6 +258,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         sample = self.conv_norm_out(sample)
         sample = nn.silu(sample)
         sample = self.conv_out(sample)
+        sample = jnp.transpose(sample, (0, 3, 1, 2))

         if not return_dict:
             return (sample,)
...
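
With the two transposes above, the Flax UNet takes and returns channels-first tensors like its PyTorch counterpart, while the convolutions still run in Flax's native NHWC layout:

    import jax.numpy as jnp

    sample = jnp.zeros((1, 4, 64, 64))                 # (N, C, H, W), as in PyTorch
    nhwc = jnp.transpose(sample, (0, 2, 3, 1))         # (1, 64, 64, 4) for nn.Conv
    nchw = jnp.transpose(nhwc, (0, 3, 1, 2))           # back to (1, 4, 64, 64)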
unet_blocks_flax.py:

@@ -55,7 +55,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
         self.attentions = attentions

         if self.add_downsample:
-            self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

     def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
         output_states = ()
@@ -66,7 +66,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
             output_states += (hidden_states,)

         if self.add_downsample:
-            hidden_states = self.downsample(hidden_states)
+            hidden_states = self.downsamplers_0(hidden_states)
             output_states += (hidden_states,)

         return hidden_states, output_states
@@ -96,7 +96,7 @@ class FlaxDownBlock2D(nn.Module):
         self.resnets = resnets

         if self.add_downsample:
-            self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

     def __call__(self, hidden_states, temb, deterministic=True):
         output_states = ()
@@ -106,7 +106,7 @@ class FlaxDownBlock2D(nn.Module):
             output_states += (hidden_states,)

         if self.add_downsample:
-            hidden_states = self.downsample(hidden_states)
+            hidden_states = self.downsamplers_0(hidden_states)
             output_states += (hidden_states,)

         return hidden_states, output_states
@@ -151,7 +151,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
         self.attentions = attentions

         if self.add_upsample:
-            self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

     def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
         for resnet, attn in zip(self.resnets, self.attentions):
@@ -164,7 +164,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
             hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)

         if self.add_upsample:
-            hidden_states = self.upsample(hidden_states)
+            hidden_states = self.upsamplers_0(hidden_states)

         return hidden_states
@@ -196,7 +196,7 @@ class FlaxUpBlock2D(nn.Module):
         self.resnets = resnets

         if self.add_upsample:
-            self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

     def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
         for resnet in self.resnets:
for resnet in self.resnets: for resnet in self.resnets:
...@@ -208,7 +208,7 @@ class FlaxUpBlock2D(nn.Module): ...@@ -208,7 +208,7 @@ class FlaxUpBlock2D(nn.Module):
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
if self.add_upsample: if self.add_upsample:
hidden_states = self.upsample(hidden_states) hidden_states = self.upsamplers_0(hidden_states)
return hidden_states return hidden_states
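
The `downsamplers_0` / `upsamplers_0` names mirror the flattened PyTorch keys: the PyTorch blocks keep these layers in a `ModuleList`, so after `rename_key` the list index is fused into the attribute name. Illustrative, reusing the helper from `modeling_flax_pytorch_utils.py` above:

    pt_key = "down_blocks.0.downsamplers.0.conv.weight"
    print(rename_key(pt_key))   # down_blocks_0.downsamplers_0.conv.weight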
...