Unverified commit 07f07563, authored by Muyang Li and committed by GitHub

chore: release v0.3.1

parents 7214300d ad92b16a
@@ -12,8 +12,14 @@
 from .diffusers_converter import to_diffusers
 from .packer import NunchakuWeightPacker
 from .utils import is_nunchaku_format, pad
+
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
 # region utilities
...
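The logging setup above is environment-driven: the level is read once at import time, so verbosity can be raised without code changes. A minimal usage sketch (the module path is hypothetical):

    import os
    os.environ["LOG_LEVEL"] = "debug"  # set before the import; .upper() above normalizes case
    import nunchaku.converter  # hypothetical module path; importing it applies the config
    # Unknown values fall back to INFO via getattr(logging, log_level, logging.INFO)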
@@ -3,11 +3,13 @@
 import logging
 import os
 import re
 from copy import deepcopy
+from os import PathLike
 from pathlib import Path
 from typing import Optional, Tuple, Union

 import torch

+from ....utils import fetch_or_download
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from .model import CLIP, CustomCLIP, convert_to_custom_text_state_dict, get_cast_dtype
 from .pretrained import download_pretrained, get_pretrained_cfg, list_pretrained_tags_by_model
@@ -227,6 +229,7 @@ def create_model(
     pretrained_text_model: str = None,
     cache_dir: Optional[str] = None,
     skip_list: list = [],
+    pretrained_path: str | PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
 ):
     model_name = model_name.replace("/", "-")  # for callers using old naming with / in ViT names
     if isinstance(device, str):
@@ -239,8 +242,35 @@ def create_model(
     if model_cfg is not None:
         logging.info(f"Loaded {model_name} model config.")
     else:
-        logging.error(f"Model config for {model_name} not found; available models {list_models()}.")
-        raise RuntimeError(f"Model config for {model_name} not found.")
+        model_cfg = {
+            "embed_dim": 768,
+            "vision_cfg": {
+                "image_size": 336,
+                "layers": 24,
+                "width": 1024,
+                "drop_path_rate": 0,
+                "head_width": 64,
+                "mlp_ratio": 2.6667,
+                "patch_size": 14,
+                "eva_model_name": "eva-clip-l-14-336",
+                "xattn": True,
+                "fusedLN": True,
+                "rope": True,
+                "pt_hw_seq_len": 16,
+                "intp_freq": True,
+                "naiveswiglu": True,
+                "subln": True,
+            },
+            "text_cfg": {
+                "context_length": 77,
+                "vocab_size": 49408,
+                "width": 768,
+                "heads": 12,
+                "layers": 12,
+                "xattn": False,
+                "fusedLN": True,
+            },
+        }

     if "rope" in model_cfg.get("vision_cfg", {}):
         if model_cfg["vision_cfg"]["rope"]:
@@ -270,12 +300,7 @@ def create_model(
     pretrained_cfg = {}
     if pretrained:
-        checkpoint_path = ""
-        pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
-        if pretrained_cfg:
-            checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
-        elif os.path.exists(pretrained):
-            checkpoint_path = pretrained
+        checkpoint_path = fetch_or_download(pretrained_path)
         if checkpoint_path:
             logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")
@@ -379,6 +404,7 @@ def create_model_and_transforms(
     image_std: Optional[Tuple[float, ...]] = None,
     cache_dir: Optional[str] = None,
     skip_list: list = [],
+    pretrained_path: str | PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
 ):
     model = create_model(
         model_name,
@@ -396,6 +422,7 @@ def create_model_and_transforms(
         pretrained_text_model=pretrained_text_model,
         cache_dir=cache_dir,
         skip_list=skip_list,
+        pretrained_path=pretrained_path,
     )
     image_mean = image_mean or getattr(model.visual, "image_mean", None)
...
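With pretrained_path threaded through both factories, callers can point the EVA-CLIP backbone at a Hugging Face repo file or a local checkpoint; the PuLID pipeline further down uses exactly this form:

    model, _, _ = create_model_and_transforms(
        "EVA02-CLIP-L-14-336",
        "eva_clip",
        force_custom_clip=True,
        pretrained_path="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",  # repo file or local path
    )
    vision_tower = model.visual  # the pipeline keeps only the visual tower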
@@ -24,6 +24,8 @@ def pulid_forward(
     controlnet_single_block_samples=None,
     return_dict: bool = True,
     controlnet_blocks_repeat: bool = False,
+    start_timestep: float | None = None,
+    end_timestep: float | None = None,
 ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
     """
     Copied from diffusers.models.flux.transformer_flux.py
@@ -53,6 +55,16 @@ def pulid_forward(
     """
     hidden_states = self.x_embedder(hidden_states)

+    if timestep.numel() > 1:
+        timestep_float = timestep.flatten()[0].item()
+    else:
+        timestep_float = timestep.item()
+
+    if start_timestep is not None and start_timestep > timestep_float:
+        id_embeddings = None
+    if end_timestep is not None and end_timestep < timestep_float:
+        id_embeddings = None
+
     timestep = timestep.to(hidden_states.dtype) * 1000
     if guidance is not None:
         guidance = guidance.to(hidden_states.dtype) * 1000
...
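The two new arguments gate PuLID's identity injection by denoising time: id_embeddings is kept only while start_timestep <= t <= end_timestep and is dropped to None otherwise. A hypothetical call sketch, assuming pulid_forward has been bound as the transformer's forward (the tests bind it with types.MethodType) and that id_embeddings/id_weight are existing parameters of this forward:

    output = transformer.forward(
        hidden_states=latents,
        timestep=t,              # scalar tensor in [0, 1]
        id_embeddings=id_embed,  # from the PuLID encoder
        id_weight=1.0,
        start_timestep=0.2,      # identity dropped once t < 0.2
        end_timestep=0.8,        # identity dropped while t > 0.8
        return_dict=False,
    )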
@@ -81,7 +81,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         self.id_weight = id_weight
         self.pulid_ca_idx = 0
         if self.id_embeddings is not None:
-            self.set_residual_callback()
+            self.set_pulid_residual_callback()
         original_dtype = hidden_states.dtype
         original_device = hidden_states.device
@@ -129,7 +129,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         )
         if self.id_embeddings is not None:
-            self.reset_residual_callback()
+            self.reset_pulid_residual_callback()
         hidden_states = hidden_states.to(original_dtype).to(original_device)
@@ -194,21 +194,21 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         return encoder_hidden_states, hidden_states

-    def set_residual_callback(self):
+    def set_pulid_residual_callback(self):
         id_embeddings = self.id_embeddings
         pulid_ca = self.pulid_ca
         pulid_ca_idx = [self.pulid_ca_idx]
         id_weight = self.id_weight

         def callback(hidden_states):
-            ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states.to("cuda"))
+            ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states)
             pulid_ca_idx[0] += 1
             return ip

         self.callback_holder = callback
         self.m.set_residual_callback(callback)

-    def reset_residual_callback(self):
+    def reset_pulid_residual_callback(self):
         self.callback_holder = None
         self.m.set_residual_callback(None)
...
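Here self.m is the native FluxModel: set_residual_callback installs a Python callable that the C++ forward invokes on selected blocks, adding the returned tensor into hidden_states. With the FluxModel.cpp change below, the callback now receives a CUDA tensor directly, which is why the .to("cuda") copy above was dropped. Any callable with the same contract works; a hypothetical no-op for illustration:

    def zero_residual(hidden_states):
        # must return a residual matching the input's shape, dtype, and device
        return torch.zeros_like(hidden_states)

    blocks.m.set_residual_callback(zero_residual)  # 'blocks' is a NunchakuFluxTransformerBlocks
    # ... run inference ...
    blocks.m.set_residual_callback(None)  # detach the callback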
 # Adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py
 import gc
+import logging
+import os
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union

 import cv2
@@ -13,9 +16,9 @@ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
 from diffusers.utils import replace_example_docstring
 from facexlib.parsing import init_parsing_model
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub import snapshot_download
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from insightface.app import FaceAnalysis
-from safetensors.torch import load_file
 from torch import nn
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import normalize, resize
@@ -24,10 +27,54 @@ from ..models.pulid.encoders_transformer import IDFormer, PerceiverAttentionCA
 from ..models.pulid.eva_clip import create_model_and_transforms
 from ..models.pulid.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from ..models.pulid.utils import img2tensor, resize_numpy_image_long, tensor2img
+from ..models.transformers import NunchakuFluxTransformer2dModel
+from ..utils import load_state_dict_in_safetensors, sha256sum
+
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+def check_antelopev2_dir(antelopev2_dirpath: str | os.PathLike[str]) -> bool:
+    antelopev2_dirpath = Path(antelopev2_dirpath)
+    required_files = {
+        "1k3d68.onnx": "df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc",
+        "2d106det.onnx": "f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf",
+        "genderage.onnx": "4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb",
+        "glintr100.onnx": "4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf",
+        "scrfd_10g_bnkps.onnx": "5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91",
+    }
+    if not antelopev2_dirpath.is_dir():
+        logger.debug(f"Directory does not exist: {antelopev2_dirpath}")
+        return False
+    for filename, expected_hash in required_files.items():
+        filepath = antelopev2_dirpath / filename
+        if not filepath.exists():
+            logger.debug(f"Missing file: {filename}")
+            return False
+        if expected_hash != "<SKIP_HASH>" and not sha256sum(filepath) == expected_hash:
+            logger.debug(f"Hash mismatch for: {filename}")
+            return False
+    return True
+
+
 class PuLIDPipeline(nn.Module):
-    def __init__(self, dit, device, weight_dtype=torch.bfloat16, onnx_provider="gpu", *args, **kwargs):
+    def __init__(
+        self,
+        dit: NunchakuFluxTransformer2dModel,
+        device: str | torch.device,
+        weight_dtype: str | torch.dtype = torch.bfloat16,
+        onnx_provider: str = "gpu",
+        pulid_path: str | os.PathLike[str] = "guozinan/PuLID/pulid_flux_v0.9.1.safetensors",
+        eva_clip_path: str | os.PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
+        insightface_dirpath: str | os.PathLike[str] | None = None,
+        facexlib_dirpath: str | os.PathLike[str] | None = None,
+    ):
         super().__init__()
         self.device = device
         self.weight_dtype = weight_dtype
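check_antelopev2_dir makes the antelopev2 download idempotent: when every ONNX file is present with the pinned SHA-256, the snapshot_download in __init__ is skipped. A standalone sketch mirroring the default cache layout used below:

    from pathlib import Path
    from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

    antelopev2 = Path(HUGGINGFACE_HUB_CACHE) / "insightface" / "models" / "antelopev2"
    if not check_antelopev2_dir(antelopev2):
        print("antelopev2 models missing or corrupt; __init__ will re-download them")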
@@ -50,6 +97,11 @@ class PuLIDPipeline(nn.Module):
         # preprocessors
         # face align and parsing
+        if facexlib_dirpath is None:
+            facexlib_dirpath = Path(HUGGINGFACE_HUB_CACHE) / "facexlib"
+        facexlib_dirpath = Path(facexlib_dirpath)
+
         self.face_helper = FaceRestoreHelper(
             upscale_factor=1,
             face_size=512,
@@ -57,11 +109,17 @@ class PuLIDPipeline(nn.Module):
             det_model="retinaface_resnet50",
             save_ext="png",
             device=self.device,
+            model_rootpath=str(facexlib_dirpath),
         )
         self.face_helper.face_parse = None
-        self.face_helper.face_parse = init_parsing_model(model_name="bisenet", device=self.device)
+        self.face_helper.face_parse = init_parsing_model(
+            model_name="bisenet", device=self.device, model_rootpath=str(facexlib_dirpath)
+        )
         # clip-vit backbone
-        model, _, _ = create_model_and_transforms("EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True)
+        model, _, _ = create_model_and_transforms(
+            "EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True, pretrained_path=eva_clip_path
+        )
         model = model.visual
         self.clip_vision_model = model.to(self.device, dtype=self.weight_dtype)
         eva_transform_mean = getattr(self.clip_vision_model, "image_mean", OPENAI_DATASET_MEAN)
@@ -72,41 +130,51 @@ class PuLIDPipeline(nn.Module):
             eva_transform_std = (eva_transform_std,) * 3
         self.eva_transform_mean = eva_transform_mean
         self.eva_transform_std = eva_transform_std

         # antelopev2
-        snapshot_download("DIAMONIK7777/antelopev2", local_dir="models/antelopev2")
+        if insightface_dirpath is None:
+            insightface_dirpath = Path(HUGGINGFACE_HUB_CACHE) / "insightface"
+        insightface_dirpath = Path(insightface_dirpath)
+        if insightface_dirpath is not None:
+            antelopev2_dirpath = insightface_dirpath / "models" / "antelopev2"
+        else:
+            antelopev2_dirpath = None
+        if antelopev2_dirpath is None or not check_antelopev2_dir(antelopev2_dirpath):
+            snapshot_download("DIAMONIK7777/antelopev2", local_dir=antelopev2_dirpath)
         providers = (
             ["CPUExecutionProvider"] if onnx_provider == "cpu" else ["CUDAExecutionProvider", "CPUExecutionProvider"]
         )
-        self.app = FaceAnalysis(name="antelopev2", root=".", providers=providers)
+        self.app = FaceAnalysis(name="antelopev2", root=insightface_dirpath, providers=providers)
         self.app.prepare(ctx_id=0, det_size=(640, 640))
-        self.handler_ante = insightface.model_zoo.get_model("models/antelopev2/glintr100.onnx", providers=providers)
+        self.handler_ante = insightface.model_zoo.get_model(
+            str(antelopev2_dirpath / "glintr100.onnx"), providers=providers
+        )
         self.handler_ante.prepare(ctx_id=0)

-        gc.collect()
-        torch.cuda.empty_cache()
-
-        # other configs
-        self.debug_img_list = []
-
-    def load_pretrain(self, pretrain_path=None, version="v0.9.0"):
-        hf_hub_download("guozinan/PuLID", f"pulid_flux_{version}.safetensors", local_dir="models")
-        ckpt_path = f"models/pulid_flux_{version}.safetensors"
-        if pretrain_path is not None:
-            ckpt_path = pretrain_path
-        state_dict = load_file(ckpt_path)
-        state_dict_dict = {}
+        # pulid model
+        state_dict = load_state_dict_in_safetensors(pulid_path)
+        module_state_dict = {}
         for k, v in state_dict.items():
             module = k.split(".")[0]
-            state_dict_dict.setdefault(module, {})
+            module_state_dict.setdefault(module, {})
             new_k = k[len(module) + 1 :]
-            state_dict_dict[module][new_k] = v
+            module_state_dict[module][new_k] = v

-        for module in state_dict_dict:
-            print(f"loading from {module}")
-            getattr(self, module).load_state_dict(state_dict_dict[module], strict=True)
+        for module in module_state_dict:
+            logging.debug(f"loading from {module}")
+            getattr(self, module).load_state_dict(module_state_dict[module], strict=True)

         del state_dict
-        del state_dict_dict
+        del module_state_dict
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # other configs
+        self.debug_img_list = []

     def to_gray(self, img):
         x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
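The loader splits each flat checkpoint key on its first dot and routes the remainder into the attribute named by the prefix (e.g. self.pulid_ca). A tiny worked example with hypothetical keys:

    state_dict = {"pulid_encoder.proj.weight": 1, "pulid_ca.0.to_q.weight": 2}
    module_state_dict = {}
    for k, v in state_dict.items():
        module = k.split(".")[0]
        module_state_dict.setdefault(module, {})
        module_state_dict[module][k[len(module) + 1 :]] = v
    assert module_state_dict == {
        "pulid_encoder": {"proj.weight": 1},
        "pulid_ca": {"0.to_q.weight": 2},
    }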
@@ -206,7 +274,6 @@ class PuLIDFluxPipeline(FluxPipeline):
         pulid_device="cuda",
         weight_dtype=torch.bfloat16,
         onnx_provider="gpu",
-        pretrained_model=None,
     ):
         super().__init__(
             scheduler=scheduler,
@@ -232,7 +299,6 @@ class PuLIDFluxPipeline(FluxPipeline):
             weight_dtype=self.weight_dtype,
             onnx_provider=self.onnx_provider,
         )
-        self.pulid_model.load_pretrain(pretrained_model)

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
...
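After this change the PuLID weights load eagerly in __init__ from pulid_path, so the separate load_pretrain() step (and the pretrained_model argument) disappears. A hypothetical instantiation using the new keyword paths:

    pulid_model = PuLIDPipeline(
        dit=transformer,  # a NunchakuFluxTransformer2dModel
        device="cuda",
        weight_dtype=torch.bfloat16,
        pulid_path="guozinan/PuLID/pulid_flux_v0.9.1.safetensors",
        eva_clip_path="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
    )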
+import hashlib
 import os
 import warnings
 from pathlib import Path
@@ -7,6 +8,14 @@
 import torch
 from huggingface_hub import hf_hub_download

+def sha256sum(filepath: str | os.PathLike[str]) -> str:
+    sha256 = hashlib.sha256()
+    with open(filepath, "rb") as f:
+        for chunk in iter(lambda: f.read(8192), b""):
+            sha256.update(chunk)
+    return sha256.hexdigest()
+
+
 def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
     path = Path(path)
...
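sha256sum streams the file in 8 KiB chunks, so large ONNX weights are hashed without being read into memory at once. For example, verifying one antelopev2 file against the hash pinned earlier (the local path is hypothetical):

    digest = sha256sum("models/antelopev2/glintr100.onnx")
    assert digest == "4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf"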
@@ -837,11 +837,8 @@ Tensor FluxModel::forward(Tensor hidden_states,
                 hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
             }
             if (residual_callback && layer % 2 == 0) {
-                Tensor cpu_input = hidden_states.copy(Device::cpu());
-                pybind11::gil_scoped_acquire gil;
-                Tensor cpu_output = residual_callback(cpu_input);
-                Tensor residual = cpu_output.copy(Device::cuda());
-                hidden_states = kernels::add(hidden_states, residual);
+                Tensor residual = residual_callback(hidden_states);
+                hidden_states = kernels::add(hidden_states, residual);
             }
         } else {
             if (size_t(layer) == transformer_blocks.size()) {
@@ -875,12 +872,9 @@ Tensor FluxModel::forward(Tensor hidden_states,
             size_t local_layer_idx = layer - transformer_blocks.size();
             if (residual_callback && local_layer_idx % 4 == 0) {
                 Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
-                Tensor cpu_input = callback_input.copy(Device::cpu());
-                pybind11::gil_scoped_acquire gil;
-                Tensor cpu_output = residual_callback(cpu_input);
-                Tensor residual = cpu_output.copy(Device::cuda());
+                Tensor residual = residual_callback(callback_input);
                 auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                 slice = kernels::add(slice, residual);
                 hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
             }
         }
@@ -919,6 +913,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
                                                     Tensor controlnet_block_samples,
                                                     Tensor controlnet_single_block_samples) {
+    if (offload && layer > 0) {
+        if (layer < transformer_blocks.size()) {
+            transformer_blocks.at(layer)->loadLazyParams();
+        } else {
+            single_transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
+        }
+    }
+
     if (layer < transformer_blocks.size()) {
         std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
             hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
@@ -954,6 +956,14 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
         hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
     }

+    if (offload && layer > 0) {
+        if (layer < transformer_blocks.size()) {
+            transformer_blocks.at(layer)->releaseLazyParams();
+        } else {
+            single_transformer_blocks.at(layer - transformer_blocks.size())->releaseLazyParams();
+        }
+    }
+
     return {hidden_states, encoder_hidden_states};
 }
...
@@ -189,6 +189,9 @@ public:
     std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
     std::function<Tensor(const Tensor &)> residual_callback;
+    bool isOffloadEnabled() const {
+        return offload;
+    }
 private:
     bool offload;
...
 import pytest
+import torch
+from diffusers import FluxPipeline
+
+from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision, is_turing

 from .utils import run_test

@@ -54,3 +57,38 @@ def test_flux_dev_turbo8_ghibsky_1024x1024():
         cache_threshold=0,
         expected_lpips=0.310 if get_precision() == "int4" else 0.168,
     )
+
+
+def test_kohya_lora():
+    precision = get_precision()  # auto-detect whether the precision is 'int4' or 'fp4' based on your GPU
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    )
+    pipeline = FluxPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
+    ).to("cuda")
+    transformer.update_lora_params("mit-han-lab/nunchaku-test-models/hand_drawn_game.safetensors")
+    transformer.set_lora_strength(1)
+
+    prompt = (
+        "masterful impressionism oil painting titled 'the violinist', the composition follows the rule of thirds, "
+        "placing the violinist centrally in the frame. the subject is a young woman with fair skin and light blonde "
+        "hair is styled in a long, flowing hairstyle with natural waves. she is dressed in an opulent, "
+        "luxurious silver silk gown with a high waist and intricate gold detailing along the bodice. "
+        "the gown's texture is smooth and reflective. she holds a violin under her chin, "
+        "her right hand poised to play, and her left hand supporting the neck of the instrument. "
+        "she wears a delicate gold necklace with small, sparkling gemstones that catch the light. "
+        "her beautiful eyes focused on the viewer. the background features an elegantly furnished room "
+        "with classical late 19th century decor. to the left, there is a large, ornate portrait of "
+        "a man in a dark suit, set in a gilded frame. below this, a wooden desk with a closed book. "
+        "to the right, a red upholstered chair with a wooden frame is partially visible. "
+        "the room is bathed in natural light streaming through a window with red curtains, "
+        "creating a warm, inviting atmosphere. the lighting highlights the violinist, "
+        "casting soft shadows that enhance the depth and realism of the scene, highly aesthetic, "
+        "harmonious colors, impressionist strokes, "
+        "<lora:style-impressionist_strokes-flux-by_daalis:1.0> <lora:image_upgrade-flux-by_zeronwo7829:1.0>"
+    )
+    image = pipeline(prompt, num_inference_steps=20, guidance_scale=3.5).images[0]
+    image.save(f"flux.1-dev-{precision}-1.png")
+import gc
 from types import MethodType

 import numpy as np
@@ -15,6 +16,8 @@ from nunchaku.utils import get_precision, is_turing
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_pulid():
+    gc.collect()
+    torch.cuda.empty_cache()
     precision = get_precision()  # auto-detect whether the precision is 'int4' or 'fp4' based on your GPU
     transformer = NunchakuFluxTransformer2dModel.from_pretrained(
         f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
...