"tests/vscode:/vscode.git/clone" did not exist on "3918d6a9d67e79c45746607d3ca726ddd641a3d1"
Commit 3d1b9667 authored by wangwei990215's avatar wangwei990215
Browse files

更新diffusers文件夹

parent b6a53272
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import deprecate
from .transformers.dual_transformer_2d import DualTransformer2DModel
class DualTransformer2DModel(DualTransformer2DModel):
deprecation_message = "Importing `DualTransformer2DModel` from `diffusers.models.dual_transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel`, instead."
deprecate("DualTransformer2DModel", "0.29", deprecation_message)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from ..utils import deprecate
from .activations import get_activation
from .attention_processor import Attention
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def get_2d_sincos_pos_embed(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
"""
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
height=224,
width=224,
patch_size=16,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
interpolation_scale=1,
):
super().__init__()
num_patches = (height // patch_size) * (width // patch_size)
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
if layer_norm:
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
else:
self.norm = None
self.patch_size = patch_size
# See:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size
self.interpolation_scale = interpolation_scale
pos_embed = get_2d_sincos_pos_embed(
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
# Interpolate positional embeddings if needed.
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
if self.height != height or self.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=self.pos_embed.shape[-1],
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
)
pos_embed = torch.from_numpy(pos_embed)
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
else:
pos_embed = self.pos_embed
return (latent + pos_embed).to(latent.dtype)
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
):
super().__init__()
linear_cls = nn.Linear
self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = get_activation(act_fn)
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
if post_act_fn is None:
self.post_act = None
else:
self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
def __init__(
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
):
super().__init__()
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.log = log
self.flip_sin_to_cos = flip_sin_to_cos
if set_W_to_weight:
# to delete later
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.weight = self.W
def forward(self, x):
if self.log:
x = torch.log(x)
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
if self.flip_sin_to_cos:
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
else:
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out
class SinusoidalPositionalEmbedding(nn.Module):
"""Apply positional information to a sequence of embeddings.
Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
them
Args:
embed_dim: (int): Dimension of the positional embedding.
max_seq_length: Maximum sequence length to apply positional embeddings
"""
def __init__(self, embed_dim: int, max_seq_length: int = 32):
super().__init__()
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
pe = torch.zeros(1, max_seq_length, embed_dim)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
def forward(self, x):
_, seq_length, _ = x.shape
x = x + self.pe[:, :seq_length]
return x
class ImagePositionalEmbeddings(nn.Module):
"""
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
height and width of the latent space.
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
For VQ-diffusion:
Output vector embeddings are used as input for the transformer.
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
Args:
num_embed (`int`):
Number of embeddings for the latent pixels embeddings.
height (`int`):
Height of the latent image i.e. the number of height embeddings.
width (`int`):
Width of the latent image i.e. the number of width embeddings.
embed_dim (`int`):
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
"""
def __init__(
self,
num_embed: int,
height: int,
width: int,
embed_dim: int,
):
super().__init__()
self.height = height
self.width = width
self.num_embed = num_embed
self.embed_dim = embed_dim
self.emb = nn.Embedding(self.num_embed, embed_dim)
self.height_emb = nn.Embedding(self.height, embed_dim)
self.width_emb = nn.Embedding(self.width, embed_dim)
def forward(self, index):
emb = self.emb(index)
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
# 1 x H x D -> 1 x H x 1 x D
height_emb = height_emb.unsqueeze(2)
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
# 1 x W x D -> 1 x 1 x W x D
width_emb = width_emb.unsqueeze(1)
pos_emb = height_emb + width_emb
# 1 x H x W x D -> 1 x L xD
pos_emb = pos_emb.view(1, self.height * self.width, -1)
emb = emb + pos_emb[:, : emb.shape[1], :]
return emb
class LabelEmbedding(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
Args:
num_classes (`int`): The number of classes.
hidden_size (`int`): The size of the vector embeddings.
dropout_prob (`float`): The probability of dropping a label.
"""
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
else:
drop_ids = torch.tensor(force_drop_ids == 1)
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels: torch.LongTensor, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (self.training and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
class TextImageProjection(nn.Module):
def __init__(
self,
text_embed_dim: int = 1024,
image_embed_dim: int = 768,
cross_attention_dim: int = 768,
num_image_text_embeds: int = 10,
):
super().__init__()
self.num_image_text_embeds = num_image_text_embeds
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
batch_size = text_embeds.shape[0]
# image
image_text_embeds = self.image_embeds(image_embeds)
image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
# text
text_embeds = self.text_proj(text_embeds)
return torch.cat([image_text_embeds, text_embeds], dim=1)
class ImageProjection(nn.Module):
def __init__(
self,
image_embed_dim: int = 768,
cross_attention_dim: int = 768,
num_image_text_embeds: int = 32,
):
super().__init__()
self.num_image_text_embeds = num_image_text_embeds
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds: torch.FloatTensor):
batch_size = image_embeds.shape[0]
# image
image_embeds = self.image_embeds(image_embeds)
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
image_embeds = self.norm(image_embeds)
return image_embeds
class IPAdapterFullImageProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
super().__init__()
from .attention import FeedForward
self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds: torch.FloatTensor):
return self.norm(self.ff(image_embeds))
class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
def forward(self, timestep, class_labels, hidden_dtype=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
class_labels = self.class_embedder(class_labels) # (N, D)
conditioning = timesteps_emb + class_labels # (N, D)
return conditioning
class TextTimeEmbedding(nn.Module):
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
super().__init__()
self.norm1 = nn.LayerNorm(encoder_dim)
self.pool = AttentionPooling(num_heads, encoder_dim)
self.proj = nn.Linear(encoder_dim, time_embed_dim)
self.norm2 = nn.LayerNorm(time_embed_dim)
def forward(self, hidden_states):
hidden_states = self.norm1(hidden_states)
hidden_states = self.pool(hidden_states)
hidden_states = self.proj(hidden_states)
hidden_states = self.norm2(hidden_states)
return hidden_states
class TextImageTimeEmbedding(nn.Module):
def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
super().__init__()
self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
self.text_norm = nn.LayerNorm(time_embed_dim)
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
# text
time_text_embeds = self.text_proj(text_embeds)
time_text_embeds = self.text_norm(time_text_embeds)
# image
time_image_embeds = self.image_proj(image_embeds)
return time_image_embeds + time_text_embeds
class ImageTimeEmbedding(nn.Module):
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
super().__init__()
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
self.image_norm = nn.LayerNorm(time_embed_dim)
def forward(self, image_embeds: torch.FloatTensor):
# image
time_image_embeds = self.image_proj(image_embeds)
time_image_embeds = self.image_norm(time_image_embeds)
return time_image_embeds
class ImageHintTimeEmbedding(nn.Module):
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
super().__init__()
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
self.image_norm = nn.LayerNorm(time_embed_dim)
self.input_hint_block = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 32, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(32, 32, 3, padding=1),
nn.SiLU(),
nn.Conv2d(32, 96, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(96, 96, 3, padding=1),
nn.SiLU(),
nn.Conv2d(96, 256, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(256, 4, 3, padding=1),
)
def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
# image
time_image_embeds = self.image_proj(image_embeds)
time_image_embeds = self.image_norm(time_image_embeds)
hint = self.input_hint_block(hint)
return time_image_embeds, hint
class AttentionPooling(nn.Module):
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
def __init__(self, num_heads, embed_dim, dtype=None):
super().__init__()
self.dtype = dtype
self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
self.num_heads = num_heads
self.dim_per_head = embed_dim // self.num_heads
def forward(self, x):
bs, length, width = x.size()
def shape(x):
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, -1, self.num_heads, self.dim_per_head)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
# (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
x = x.transpose(1, 2)
return x
class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
# (bs*n_heads, class_token_length, dim_per_head)
q = shape(self.q_proj(class_token))
# (bs*n_heads, length+class_token_length, dim_per_head)
k = shape(self.k_proj(x))
v = shape(self.v_proj(x))
# (bs*n_heads, class_token_length, length+class_token_length):
scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
# (bs*n_heads, dim_per_head, class_token_length)
a = torch.einsum("bts,bcs->bct", weight, v)
# (bs, length+1, width)
a = a.reshape(bs, -1, 1).transpose(1, 2)
return a[:, 0, :] # cls_token
def get_fourier_embeds_from_boundingbox(embed_dim, box):
"""
Args:
embed_dim: int
box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
Returns:
[B x N x embed_dim] tensor of positional embeddings
"""
batch_size, num_boxes = box.shape[:2]
emb = 100 ** (torch.arange(embed_dim) / embed_dim)
emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
emb = emb * box.unsqueeze(-1)
emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
return emb
class GLIGENTextBoundingboxProjection(nn.Module):
def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
super().__init__()
self.positive_len = positive_len
self.out_dim = out_dim
self.fourier_embedder_dim = fourier_freqs
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
if isinstance(out_dim, tuple):
out_dim = out_dim[0]
if feature_type == "text-only":
self.linears = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
elif feature_type == "text-image":
self.linears_text = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.linears_image = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
def forward(
self,
boxes,
masks,
positive_embeddings=None,
phrases_masks=None,
image_masks=None,
phrases_embeddings=None,
image_embeddings=None,
):
masks = masks.unsqueeze(-1)
# embedding position (it may includes padding as placeholder)
xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C
# learnable null embedding
xyxy_null = self.null_position_feature.view(1, 1, -1)
# replace padding with learnable null embedding
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
# positionet with text only information
if positive_embeddings is not None:
# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
# positionet with text and image infomation
else:
phrases_masks = phrases_masks.unsqueeze(-1)
image_masks = image_masks.unsqueeze(-1)
# learnable null embedding
text_null = self.null_text_feature.view(1, 1, -1)
image_null = self.null_image_feature.view(1, 1, -1)
# replace padding with learnable null embedding
phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
objs = torch.cat([objs_text, objs_image], dim=1)
return objs
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.
Reference:
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.use_additional_conditions = use_additional_conditions
if use_additional_conditions:
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
if self.use_additional_conditions:
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
else:
conditioning = timesteps_emb
return conditioning
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, num_tokens=120):
super().__init__()
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class IPAdapterPlusImageProjection(nn.Module):
"""Resampler of IP-Adapter Plus.
Args:
----
embed_dims (int): The feature dimension. Defaults to 768.
output_dims (int): The number of output channels, that is the same
number of the channels in the
`unet.config.cross_attention_dim`. Defaults to 1024.
hidden_dims (int): The number of hidden channels. Defaults to 1280.
depth (int): The number of blocks. Defaults to 8.
dim_head (int): The number of head channels. Defaults to 64.
heads (int): Parallel attention heads. Defaults to 16.
num_queries (int): The number of queries. Defaults to 8.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
"""
def __init__(
self,
embed_dims: int = 768,
output_dims: int = 1024,
hidden_dims: int = 1280,
depth: int = 4,
dim_head: int = 64,
heads: int = 16,
num_queries: int = 8,
ffn_ratio: float = 4,
) -> None:
super().__init__()
from .attention import FeedForward # Lazy import to avoid circular import
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
self.proj_in = nn.Linear(embed_dims, hidden_dims)
self.proj_out = nn.Linear(hidden_dims, output_dims)
self.norm_out = nn.LayerNorm(output_dims)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
nn.LayerNorm(hidden_dims),
nn.LayerNorm(hidden_dims),
Attention(
query_dim=hidden_dims,
dim_head=dim_head,
heads=heads,
out_bias=False,
),
nn.Sequential(
nn.LayerNorm(hidden_dims),
FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
),
]
)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
----
x (torch.Tensor): Input Tensor.
Returns:
-------
torch.Tensor: Output Tensor.
"""
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
for ln0, ln1, attn, ff in self.layers:
residual = latents
encoder_hidden_states = ln0(x)
latents = ln1(latents)
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
latents = attn(latents, encoder_hidden_states) + residual
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
class MultiIPAdapterImageProjection(nn.Module):
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
super().__init__()
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
def forward(self, image_embeds: List[torch.FloatTensor]):
projected_image_embeds = []
# currently, we accept `image_embeds` as
# 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
# 2. list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
if not isinstance(image_embeds, list):
deprecation_message = (
"You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning."
)
deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
image_embeds = [image_embeds.unsqueeze(1)]
if len(image_embeds) != len(self.image_projection_layers):
raise ValueError(
f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
)
for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
image_embed = image_projection_layer(image_embed)
image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
projected_image_embeds.append(image_embed)
return projected_image_embeds
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import flax.linen as nn
import jax.numpy as jnp
def get_sinusoidal_embeddings(
timesteps: jnp.ndarray,
embedding_dim: int,
freq_shift: float = 1,
min_timescale: float = 1,
max_timescale: float = 1.0e4,
flip_sin_to_cos: bool = False,
scale: float = 1.0,
) -> jnp.ndarray:
"""Returns the positional encoding (same as Tensor2Tensor).
Args:
timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
embedding_dim: The number of output channels.
min_timescale: The smallest time unit (should probably be 0.0).
max_timescale: The largest time unit.
Returns:
a Tensor of timing signals [N, num_channels]
"""
assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
num_timescales = float(embedding_dim // 2)
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
# scale embeddings
scaled_time = scale * emb
if flip_sin_to_cos:
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
else:
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
return signal
class FlaxTimestepEmbedding(nn.Module):
r"""
Time step Embedding Module. Learns embeddings for input time steps.
Args:
time_embed_dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, temb):
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
temb = nn.silu(temb)
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
return temb
class FlaxTimesteps(nn.Module):
r"""
Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
Args:
dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
"""
dim: int = 32
flip_sin_to_cos: bool = False
freq_shift: float = 1
@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(
timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# IMPORTANT: #
###################################################################
# ----------------------------------------------------------------#
# This file is deprecated and will be removed soon #
# (as soon as PEFT will become a required dependency for LoRA) #
# ----------------------------------------------------------------#
###################################################################
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate, logging
from ..utils.import_utils import is_transformers_available
if is_transformers_available():
from transformers import CLIPTextModel, CLIPTextModelWithProjection
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def text_encoder_attn_modules(text_encoder):
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
else:
raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
return attn_modules
def text_encoder_mlp_modules(text_encoder):
mlp_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
mlp_mod = layer.mlp
name = f"text_model.encoder.layers.{i}.mlp"
mlp_modules.append((name, mlp_mod))
else:
raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
return mlp_modules
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_scale = lora_scale
attn_module.k_proj.lora_scale = lora_scale
attn_module.v_proj.lora_scale = lora_scale
attn_module.out_proj.lora_scale = lora_scale
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1.lora_scale = lora_scale
mlp_module.fc2.lora_scale = lora_scale
class PatchedLoraProjection(torch.nn.Module):
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("PatchedLoraProjection", "1.0.0", deprecation_message)
super().__init__()
from ..models.lora import LoRALinearLayer
self.regular_linear_layer = regular_linear_layer
device = self.regular_linear_layer.weight.device
if dtype is None:
dtype = self.regular_linear_layer.weight.dtype
self.lora_linear_layer = LoRALinearLayer(
self.regular_linear_layer.in_features,
self.regular_linear_layer.out_features,
network_alpha=network_alpha,
device=device,
dtype=dtype,
rank=rank,
)
self.lora_scale = lora_scale
# overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
# when saving the whole text encoder model and when LoRA is unloaded or fused
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
if self.lora_linear_layer is None:
return self.regular_linear_layer.state_dict(
*args, destination=destination, prefix=prefix, keep_vars=keep_vars
)
return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
if self.lora_linear_layer is None:
return
dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
w_orig = self.regular_linear_layer.weight.data.float()
w_up = self.lora_linear_layer.up.weight.data.float()
w_down = self.lora_linear_layer.down.weight.data.float()
if self.lora_linear_layer.network_alpha is not None:
w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
"LoRA weights will not be fused."
)
self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
# we can drop the lora layer now
self.lora_linear_layer = None
# offload the up and down matrices to CPU to not blow the memory
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
self.lora_scale = lora_scale
def _unfuse_lora(self):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.regular_linear_layer.weight.data
dtype, device = fused_weight.dtype, fused_weight.device
w_up = self.w_up.to(device=device).float()
w_down = self.w_down.to(device).float()
unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
self.w_up = None
self.w_down = None
def forward(self, input):
if self.lora_scale is None:
self.lora_scale = 1.0
if self.lora_linear_layer is None:
return self.regular_linear_layer(input)
return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
class LoRALinearLayer(nn.Module):
r"""
A linear layer that is used with LoRA.
Parameters:
in_features (`int`):
Number of input features.
out_features (`int`):
Number of output features.
rank (`int`, `optional`, defaults to 4):
The rank of the LoRA layer.
network_alpha (`float`, `optional`, defaults to `None`):
The value of the network alpha used for stable learning and preventing underflow. This value has the same
meaning as the `--network_alpha` option in the kohya-ss trainer script. See
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
device (`torch.device`, `optional`, defaults to `None`):
The device to use for the layer's weights.
dtype (`torch.dtype`, `optional`, defaults to `None`):
The dtype to use for the layer's weights.
"""
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 4,
network_alpha: Optional[float] = None,
device: Optional[Union[torch.device, str]] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
deprecation_message = "Use of `LoRALinearLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRALinearLayer", "1.0.0", deprecation_message)
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank
self.out_features = out_features
self.in_features = in_features
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank
return up_hidden_states.to(orig_dtype)
class LoRAConv2dLayer(nn.Module):
r"""
A convolutional layer that is used with LoRA.
Parameters:
in_features (`int`):
Number of input features.
out_features (`int`):
Number of output features.
rank (`int`, `optional`, defaults to 4):
The rank of the LoRA layer.
kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1):
The kernel size of the convolution.
stride (`int` or `tuple` of two `int`, `optional`, defaults to 1):
The stride of the convolution.
padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0):
The padding of the convolution.
network_alpha (`float`, `optional`, defaults to `None`):
The value of the network alpha used for stable learning and preventing underflow. This value has the same
meaning as the `--network_alpha` option in the kohya-ss trainer script. See
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
"""
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 4,
kernel_size: Union[int, Tuple[int, int]] = (1, 1),
stride: Union[int, Tuple[int, int]] = (1, 1),
padding: Union[int, Tuple[int, int], str] = 0,
network_alpha: Optional[float] = None,
):
super().__init__()
deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message)
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank
return up_hidden_states.to(orig_dtype)
class LoRACompatibleConv(nn.Conv2d):
"""
A convolutional layer that can be used with LoRA.
"""
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
deprecation_message = "Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
self.lora_layer = lora_layer
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("set_lora_layer", "1.0.0", deprecation_message)
self.lora_layer = lora_layer
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
if self.lora_layer is None:
return
dtype, device = self.weight.data.dtype, self.weight.data.device
w_orig = self.weight.data.float()
w_up = self.lora_layer.up.weight.data.float()
w_down = self.lora_layer.down.weight.data.float()
if self.lora_layer.network_alpha is not None:
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
fusion = fusion.reshape((w_orig.shape))
fused_weight = w_orig + (lora_scale * fusion)
if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
"LoRA weights will not be fused."
)
self.weight.data = fused_weight.to(device=device, dtype=dtype)
# we can drop the lora layer now
self.lora_layer = None
# offload the up and down matrices to CPU to not blow the memory
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.weight.data
dtype, device = fused_weight.data.dtype, fused_weight.data.device
self.w_up = self.w_up.to(device=device).float()
self.w_down = self.w_down.to(device).float()
fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
fusion = fusion.reshape((fused_weight.shape))
unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
self.w_up = None
self.w_down = None
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
if self.padding_mode != "zeros":
hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode)
padding = (0, 0)
else:
padding = self.padding
original_outputs = F.conv2d(
hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups
)
if self.lora_layer is None:
return original_outputs
else:
return original_outputs + (scale * self.lora_layer(hidden_states))
class LoRACompatibleLinear(nn.Linear):
"""
A Linear layer that can be used with LoRA.
"""
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
self.lora_layer = lora_layer
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("set_lora_layer", "1.0.0", deprecation_message)
self.lora_layer = lora_layer
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
if self.lora_layer is None:
return
dtype, device = self.weight.data.dtype, self.weight.data.device
w_orig = self.weight.data.float()
w_up = self.lora_layer.up.weight.data.float()
w_down = self.lora_layer.down.weight.data.float()
if self.lora_layer.network_alpha is not None:
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
if safe_fusing and torch.isnan(fused_weight).any().item():
raise ValueError(
"This LoRA weight seems to be broken. "
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
"LoRA weights will not be fused."
)
self.weight.data = fused_weight.to(device=device, dtype=dtype)
# we can drop the lora layer now
self.lora_layer = None
# offload the up and down matrices to CPU to not blow the memory
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.weight.data
dtype, device = fused_weight.dtype, fused_weight.device
w_up = self.w_up.to(device=device).float()
w_down = self.w_down.to(device).float()
unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
self.w_up = None
self.w_down = None
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
if self.lora_layer is None:
out = super().forward(hidden_states)
return out
else:
out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
return out
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""
import re
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from ..utils import logging
logger = logging.get_logger(__name__)
def rename_key(key):
regex = r"\w+[.]\d+"
pats = re.findall(regex, key)
for pat in pats:
key = key.replace(pat, "_".join(pat.split(".")))
return key
#####################
# PyTorch => Flax #
#####################
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
# conv norm or layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
# rename attention layers
if len(pt_tuple_key) > 1:
for rename_from, rename_to in (
("to_out_0", "proj_attn"),
("to_k", "key"),
("to_v", "value"),
("to_q", "query"),
):
if pt_tuple_key[-2] == rename_from:
weight_name = pt_tuple_key[-1]
weight_name = "kernel" if weight_name == "weight" else weight_name
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
if renamed_pt_tuple_key in random_flax_state_dict:
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
return renamed_pt_tuple_key, pt_tensor.T
if (
any("norm" in str_ for str_ in pt_tuple_key)
and (pt_tuple_key[-1] == "bias")
and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
):
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
return renamed_pt_tuple_key, pt_tensor
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
return renamed_pt_tuple_key, pt_tensor
# embedding
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
return renamed_pt_tuple_key, pt_tensor
# conv layer
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
return renamed_pt_tuple_key, pt_tensor
# linear layer
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight":
pt_tensor = pt_tensor.T
return renamed_pt_tuple_key, pt_tensor
# old PyTorch layer norm weight
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
if pt_tuple_key[-1] == "gamma":
return renamed_pt_tuple_key, pt_tensor
# old PyTorch layer norm bias
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
if pt_tuple_key[-1] == "beta":
return renamed_pt_tuple_key, pt_tensor
return pt_tuple_key, pt_tensor
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
# Step 1: Convert pytorch tensor to numpy
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
# Step 2: Since the model is stateless, get random Flax params
random_flax_params = flax_model.init_weights(PRNGKey(init_key))
random_flax_state_dict = flatten_dict(random_flax_params)
flax_state_dict = {}
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
renamed_pt_key = rename_key(pt_key)
pt_tuple_key = tuple(renamed_pt_key.split("."))
# Correctly rename weight parameters
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
if flax_key in random_flax_state_dict:
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
raise ValueError(
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
)
# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
return unflatten_dict(flax_state_dict)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pickle import UnpicklingError
from typing import Any, Dict, Union
import jax
import jax.numpy as jnp
import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
from requests import HTTPError
from .. import __version__, is_torch_available
from ..utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
PushToHubMixin,
logging,
)
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
logger = logging.get_logger(__name__)
class FlaxModelMixin(PushToHubMixin):
r"""
Base class for all Flax models.
[`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
saving models.
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_flax_internal_args = ["name", "parent", "dtype"]
@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
"""
return cls(config, **kwargs)
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
"""
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
if mask is None:
return jax.tree_map(conditional_cast, params)
flat_params = flatten_dict(params)
flat_mask, _ = jax.tree_flatten(mask)
for masked, key in zip(flat_mask, flat_params.keys()):
if masked:
param = flat_params[key]
flat_params[key] = conditional_cast(param)
return unflatten_dict(flat_params)
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
the `params` in place.
This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
Arguments:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
for params you want to cast, and `False` for those you want to skip.
Examples:
```python
>>> from diffusers import FlaxUNet2DConditionModel
>>> # load model
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
>>> params = model.to_bf16(params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
... for path in flat_params
... }
>>> mask = traverse_util.unflatten_dict(mask)
>>> params = model.to_bf16(params, mask)
```"""
return self._cast_floating_to(params, jnp.bfloat16, mask)
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
Arguments:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
for params you want to cast, and `False` for those you want to skip.
Examples:
```python
>>> from diffusers import FlaxUNet2DConditionModel
>>> # Download model and configuration from huggingface.co
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # By default, the model params will be in fp32, to illustrate the use of this method,
>>> # we'll first cast to fp16 and back to fp32
>>> params = model.to_f16(params)
>>> # now cast back to fp32
>>> params = model.to_fp32(params)
```"""
return self._cast_floating_to(params, jnp.float32, mask)
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
`params` in place.
This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
Arguments:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
for params you want to cast, and `False` for those you want to skip.
Examples:
```python
>>> from diffusers import FlaxUNet2DConditionModel
>>> # load model
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # By default, the model params will be in fp32, to cast these to float16
>>> params = model.to_fp16(params)
>>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
... for path in flat_params
... }
>>> mask = traverse_util.unflatten_dict(mask)
>>> params = model.to_fp16(params, mask)
```"""
return self._cast_floating_to(params, jnp.float16, mask)
def init_weights(self, rng: jax.Array) -> Dict:
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
dtype: jnp.dtype = jnp.float32,
*model_args,
**kwargs,
):
r"""
Instantiate a pretrained Flax model from a pretrained model configuration.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:
- A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
using [`~FlaxModelMixin.save_pretrained`].
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given `dtype`.
<Tip>
This only specifies the dtype of the *computation* and does not influence the dtype of model
parameters.
If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
[`~FlaxModelMixin.to_bf16`].
</Tip>
model_args (sequence of positional arguments, *optional*):
All remaining positional arguments are passed to the underlying model's `__init__` method.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
from_pt (`bool`, *optional*, defaults to `False`):
Load the model weights from a PyTorch checkpoint save file.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it is loaded) and initiate the model (for
example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
automatically loaded:
- If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
model's `__init__` method (we assume all relevant updates to the configuration have already been
done).
- If a configuration is not provided, `kwargs` are first passed to the configuration class
initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
Remaining keys that do not correspond to any configuration attribute are passed to the underlying
model's `__init__` function.
Examples:
```python
>>> from diffusers import FlaxUNet2DConditionModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
```
If you get the error message below, you need to finetune the weights for your downstream task:
```bash
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
from_pt = kwargs.pop("from_pt", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "flax",
}
# Load config if we don't provide one
if config is None:
config, unused_kwargs = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
**kwargs,
)
model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)
# Load model
pretrained_path_with_subfolder = (
pretrained_model_name_or_path
if subfolder is None
else os.path.join(pretrained_model_name_or_path, subfolder)
)
if os.path.isdir(pretrained_path_with_subfolder):
if from_pt:
if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
)
model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
# Check if pytorch weights exist instead
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
" using `from_pt=True`."
)
else:
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_path_with_subfolder}."
)
else:
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
" internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
if from_pt:
if is_torch_available():
from .modeling_utils import load_state_dict
else:
raise EnvironmentError(
"Can't load the model in PyTorch format because PyTorch is not installed. "
"Please, install PyTorch or use native Flax weights."
)
# Step 1: Get the pytorch file
pytorch_model_file = load_state_dict(model_file)
# Step 2: Convert the weights
state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
else:
try:
with open(model_file, "rb") as state_f:
state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
try:
with open(model_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
# make sure all arrays are stored as jnp.ndarray
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
# flatten dicts
state = flatten_dict(state)
params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
shape_state = flatten_dict(unfreeze(params_shape_tree))
missing_keys = required_params - set(state.keys())
unexpected_keys = set(state.keys()) - required_params
if missing_keys:
logger.warning(
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
"Make sure to call model.init_weights to initialize the missing weights."
)
cls._missing_keys = missing_keys
for key in state.keys():
if key in shape_state and state[key].shape != shape_state[key].shape:
raise ValueError(
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
)
# remove unexpected keys to not be saved again
for unexpected_key in unexpected_keys:
del state[unexpected_key]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
" with another architecture."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
" training."
)
return model, unflatten_dict(state)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
params: Union[Dict, FrozenDict],
is_main_process: bool = True,
push_to_hub: bool = False,
**kwargs,
):
"""
Save a model and its configuration file to a directory so that it can be reloaded using the
[`~FlaxModelMixin.from_pretrained`] class method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", False)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
model_to_save = self
# Attach architecture to the config
# Save the config
if is_main_process:
model_to_save.save_config(save_directory)
# save model
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
with open(output_model_file, "wb") as f:
model_bytes = to_bytes(params)
f.write(model_bytes)
logger.info(f"Model weights saved in {output_model_file}")
if push_to_hub:
self._upload_folder(
save_directory,
repo_id,
token=token,
commit_message=commit_message,
create_pr=create_pr,
)
from dataclasses import dataclass
from ..utils import BaseOutput
@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""
from pickle import UnpicklingError
import jax
import jax.numpy as jnp
import numpy as np
from flax.serialization import from_bytes
from flax.traverse_util import flatten_dict
from ..utils import logging
logger = logging.get_logger(__name__)
#####################
# Flax => PyTorch #
#####################
# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
try:
with open(model_file, "rb") as flax_state_f:
flax_state = from_bytes(None, flax_state_f.read())
except UnpicklingError as e:
try:
with open(model_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
return load_flax_weights_in_pytorch_model(pt_model, flax_state)
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
"""Load flax checkpoints in a PyTorch model"""
try:
import torch # noqa: F401
except ImportError:
logger.error(
"Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise
# check if we have bf16 weights
is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
if any(is_type_bf16):
# convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16
# and bf16 is not fully supported in PT yet.
logger.warning(
"Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
"before loading those in PyTorch model."
)
flax_state = jax.tree_util.tree_map(
lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
)
pt_model.base_model_prefix = ""
flax_state_dict = flatten_dict(flax_state, sep=".")
pt_model_dict = pt_model.state_dict()
# keep track of unexpected & missing keys
unexpected_keys = []
missing_keys = set(pt_model_dict.keys())
for flax_key_tuple, flax_tensor in flax_state_dict.items():
flax_key_tuple_array = flax_key_tuple.split(".")
if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
elif flax_key_tuple_array[-1] == "kernel":
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
flax_tensor = flax_tensor.T
elif flax_key_tuple_array[-1] == "scale":
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
if "time_embedding" not in flax_key_tuple_array:
for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
flax_key_tuple_array[i] = (
flax_key_tuple_string.replace("_0", ".0")
.replace("_1", ".1")
.replace("_2", ".2")
.replace("_3", ".3")
.replace("_4", ".4")
.replace("_5", ".5")
.replace("_6", ".6")
.replace("_7", ".7")
.replace("_8", ".8")
.replace("_9", ".9")
)
flax_key = ".".join(flax_key_tuple_array)
if flax_key in pt_model_dict:
if flax_tensor.shape != pt_model_dict[flax_key].shape:
raise ValueError(
f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
)
else:
# add weight to pytorch dict
flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
# remove from missing keys
missing_keys.remove(flax_key)
else:
# weight is not expected by PyTorch model
unexpected_keys.append(flax_key)
pt_model.load_state_dict(pt_model_dict)
# re-transform missing_keys to list
missing_keys = list(missing_keys)
if len(unexpected_keys) > 0:
logger.warning(
"Some weights of the Flax model were not used when initializing the PyTorch model"
f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
" (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
" to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
" FlaxBertForSequenceClassification model)."
)
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
" use it for predictions and inference."
)
return pt_model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment