Unverified Commit eb901251 authored by Muyang Li, committed by GitHub

feat: async CPU offloading for Python backend (#624)

* tmp

* update

* update

* finished the offloading impl

* the offloading is buggy

* update utils

* the offloading is still buggy

* update

* correctness and speedup done; need to check the vram overhead

* done

* final debugging

* update

* update

* correct now

* fix

* update

* use per-layer offloading

* fix the offloading on 5090

* support setting the num_blocks_on_gpu

* change the import name
parent ea99072a
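In short, this commit lets the quantized Qwen-Image transformer stream its blocks between CPU and GPU asynchronously during each forward pass. A minimal sketch of the user-facing change, condensed from the two example scripts updated below; the keyword arguments `num_blocks_on_gpu` and `use_pin_memory` are forwarded by the new `set_offload` to the `CPUOffloadManager` added further down:

import torch
from diffusers import QwenImagePipeline
from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
from nunchaku.utils import get_precision

transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
    f"nunchaku-tech/nunchaku-qwen-image/svdq-{get_precision()}_r32-qwen-image.safetensors"
)
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16)
# Per-layer offloading: keep one block resident on GPU, stream the rest from pinned CPU memory.
transformer.set_offload(True, num_blocks_on_gpu=1, use_pin_memory=True)
# Let diffusers offload the other components, but keep the transformer under the new manager.
pipe._exclude_from_cpu_offload.append("transformer")
pipe.enable_sequential_cpu_offload()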
import math
import torch
-from diffusers import FlowMatchEulerDiscreteScheduler
+from diffusers import FlowMatchEulerDiscreteScheduler, QwenImagePipeline
from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
-from nunchaku.pipeline.pipeline_qwenimage import NunchakuQwenImagePipeline
-from nunchaku.utils import get_precision
+from nunchaku.utils import get_gpu_memory, get_precision
# From https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
scheduler_config = {
@@ -27,7 +26,7 @@ scheduler_config = {
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
num_inference_steps = 4 # you can also use the 8-step model to improve the quality
-rank = 32 # you can also use the r128 or 8-step model to improve the quality
+rank = 32 # you can also use the rank=128 model to improve the quality
model_paths = {
4: f"nunchaku-tech/nunchaku-qwen-image/svdq-{get_precision()}_r{rank}-qwen-image-lightningv1.0-4steps.safetensors",
8: f"nunchaku-tech/nunchaku-qwen-image/svdq-{get_precision()}_r{rank}-qwen-image-lightningv1.1-8steps.safetensors",
@@ -35,10 +34,18 @@ model_paths = {
# Load the model
transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(model_paths[num_inference_steps])
-pipe = NunchakuQwenImagePipeline.from_pretrained(
+pipe = QwenImagePipeline.from_pretrained(
"Qwen/Qwen-Image", transformer=transformer, scheduler=scheduler, torch_dtype=torch.bfloat16
)
if get_gpu_memory() > 18:
pipe.enable_model_cpu_offload()
else:
# use per-layer offloading for low VRAM. This only requires 3-4GB of VRAM.
transformer.set_offload(True)
pipe._exclude_from_cpu_offload.append("transformer")
pipe.enable_sequential_cpu_offload()
prompt = """Bookstore window display. A sign displays “New Arrivals This Week”. Below, a shelf tag with the text “Best-Selling Novels Here”. To the side, a colorful poster advertises “Author Meet And Greet on Saturday” with a central portrait of the author. There are four books on the bookshelf, namely “The light between worlds” “When stars are scattered” “The slient patient” “The night circus”""" prompt = """Bookstore window display. A sign displays “New Arrivals This Week”. Below, a shelf tag with the text “Best-Selling Novels Here”. To the side, a colorful poster advertises “Author Meet And Greet on Saturday” with a central portrait of the author. There are four books on the bookshelf, namely “The light between worlds” “When stars are scattered” “The slient patient” “The night circus”"""
negative_prompt = " " negative_prompt = " "
image = pipe( image = pipe(
......
import torch
from diffusers import QwenImagePipeline
from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
-from nunchaku.pipeline.pipeline_qwenimage import NunchakuQwenImagePipeline
-from nunchaku.utils import get_precision
+from nunchaku.utils import get_gpu_memory, get_precision
model_name = "Qwen/Qwen-Image"
rank = 32 # you can also use rank=128 model to improve the quality
# Load the model
transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
-f"nunchaku-tech/nunchaku-qwen-image/svdq-{get_precision()}_r32-qwen-image.safetensors"
+f"nunchaku-tech/nunchaku-qwen-image/svdq-{get_precision()}_r{rank}-qwen-image.safetensors"
-) # you can also use r128 model to improve the quality
+)
# currently, you need to use this pipeline to offload the model to CPU
-pipe = NunchakuQwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16)
+pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16)
if get_gpu_memory() > 18:
pipe.enable_model_cpu_offload()
else:
# use per-layer offloading for low VRAM. This only requires 3-4GB of VRAM.
transformer.set_offload(True)
pipe._exclude_from_cpu_offload.append("transformer")
pipe.enable_sequential_cpu_offload()
positive_magic = {
"en": "Ultra HD, 4K, cinematic composition.", # for english prompt,
@@ -32,4 +41,4 @@ image = pipe(
true_cfg_scale=4.0,
).images[0]
-image.save("qwen-image-r128.png")
+image.save(f"qwen-image-r{rank}.png")
@@ -23,13 +23,8 @@ when input changes are minimal. Supports SANA and Flux architectures.
import torch
from torch import nn
-from nunchaku.caching.fbcache import (
-apply_prev_hidden_states_residual,
-check_and_apply_cache,
-get_can_use_cache,
-set_buffer,
-)
-from nunchaku.models.transformers.utils import pad_tensor
+from ..utils import pad_tensor
+from .fbcache import apply_prev_hidden_states_residual, check_and_apply_cache, get_can_use_cache, set_buffer
num_transformer_blocks = 19 # FIXME
num_single_transformer_blocks = 38 # FIXME
......
@@ -27,9 +27,9 @@ from typing import Any, Dict, Optional, Union
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
-from nunchaku.caching.fbcache import check_and_apply_cache
-from nunchaku.models.embeddings import pack_rotemb
-from nunchaku.models.transformers.utils import pad_tensor
+from ..models.embeddings import pack_rotemb
+from ..utils import pad_tensor
+from .fbcache import check_and_apply_cache
def cached_forward_v2(
......
@@ -36,6 +36,7 @@ void gemm_w4a4(std::optional<torch::Tensor> act, // packed act [M, K
std::optional<torch::Tensor> out_k, // packed attention [B, H, M, D]
std::optional<torch::Tensor> out_v, // packed attention [B, H, M, D]
int attn_tokens) {
TorchOpContext ctx;
spdlog::trace("running gemm_w4a4: "); spdlog::trace("running gemm_w4a4: ");
auto getTensor = [](std::optional<torch::Tensor> &t) { auto getTensor = [](std::optional<torch::Tensor> &t) {
@@ -87,6 +88,7 @@ void quantize_w4a4_act_fuse_lora(std::optional<torch::Tensor> input,
std::optional<torch::Tensor> smooth,
bool fuse_glu,
bool fp4) {
TorchOpContext ctx;
spdlog::trace("running quantize_w4a4_act_fuse_lora: "); spdlog::trace("running quantize_w4a4_act_fuse_lora: ");
...@@ -114,6 +116,7 @@ void attention_fp16(torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM] ...@@ -114,6 +116,7 @@ void attention_fp16(torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM] torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM] torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
float scale) { float scale) {
TorchOpContext ctx;
nunchaku::kernels::attention_fp16(from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale);
}
@@ -125,6 +128,7 @@ torch::Tensor gemv_awq(torch::Tensor _in_feats,
int64_t n,
int64_t k,
int64_t group_size) {
TorchOpContext ctx;
Tensor result = ::gemv_awq(from_torch(_in_feats.contiguous()),
from_torch(_kernel.contiguous()),
from_torch(_scaling_factors.contiguous()),
@@ -147,6 +151,7 @@ gemm_awq(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_
from_torch(_scaling_factors.contiguous()),
from_torch(_zeros.contiguous()));
TorchOpContext ctx;
// TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
torch::Tensor output = to_torch(result);
// Tensor::synchronizeDevice();
......
@@ -15,7 +15,8 @@ from safetensors import safe_open
from torch import nn
from nunchaku.caching.utils import FluxCachedTransformerBlocks, check_and_apply_cache
-from nunchaku.models.transformers.utils import pad_tensor
+from ...utils import pad_tensor
num_transformer_blocks = 19 # FIXME
num_single_transformer_blocks = 38 # FIXME
......
@@ -22,8 +22,8 @@ from ..._C import QuantizedFluxModel
from ..._C import utils as cutils
from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
from ...lora.flux.utils import is_nunchaku_format
-from ...utils import check_hardware_compatibility, get_precision, load_state_dict_in_safetensors
+from ...utils import check_hardware_compatibility, get_precision, load_state_dict_in_safetensors, pad_tensor
-from .utils import NunchakuModelLoaderMixin, pad_tensor
+from .utils import NunchakuModelLoaderMixin
SVD_RANK = 32
......
@@ -15,14 +15,14 @@ from huggingface_hub import utils
from torch.nn import GELU
from ...ops.fused import fused_gelu_mlp
-from ...utils import get_precision
+from ...utils import get_precision, pad_tensor
from ..attention import NunchakuBaseAttention, NunchakuFeedForward
from ..attention_processors.flux import NunchakuFluxFA2Processor, NunchakuFluxFP16AttnProcessor
from ..embeddings import NunchakuFluxPosEmbed, pack_rotemb
from ..linear import SVDQW4A4Linear
from ..normalization import NunchakuAdaLayerNormZero, NunchakuAdaLayerNormZeroSingle
from ..utils import fuse_linears
-from .utils import NunchakuModelLoaderMixin, pad_tensor
+from .utils import NunchakuModelLoaderMixin
class NunchakuFluxAttention(NunchakuBaseAttention):
......
import gc
import json
import os
from pathlib import Path
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
from warnings import warn
import torch
from diffusers.models.attention_processor import Attention
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.transformer_qwenimage import (
QwenEmbedRope,
QwenImageTransformer2DModel,
@@ -16,7 +19,7 @@ from ...utils import get_precision
from ..attention import NunchakuBaseAttention, NunchakuFeedForward
from ..attention_processors.qwenimage import NunchakuQwenImageNaiveFA2Processor
from ..linear import AWQW4A16Linear, SVDQW4A4Linear
-from ..utils import fuse_linears
+from ..utils import CPUOffloadManager, fuse_linears
from .utils import NunchakuModelLoaderMixin
@@ -206,9 +209,16 @@ class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock):
class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuModelLoaderMixin):
def __init__(self, *args, **kwargs):
self.offload = kwargs.pop("offload", False)
self.offload_manager = None
self._is_initialized = False
super().__init__(*args, **kwargs)
def _patch_model(self, **kwargs):
for i, block in enumerate(self.transformer_blocks):
self.transformer_blocks[i] = NunchakuQwenImageTransformerBlock(block, scale_shift=0, **kwargs)
self._is_initialized = True
return self
@classmethod
@@ -217,9 +227,6 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
device = kwargs.get("device", "cpu")
offload = kwargs.get("offload", False)
-if offload:
-raise NotImplementedError("Offload is not supported for FluxTransformer2DModelV2")
torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
if isinstance(pretrained_model_name_or_path, str):
@@ -259,5 +266,128 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
if m.wtscale is not None:
m.wtscale = model_state_dict.pop(f"{n}.wtscale", 1.0)
transformer.load_state_dict(model_state_dict)
transformer.set_offload(offload)
return transformer
def set_offload(self, offload: bool, **kwargs):
if offload == self.offload:
# nothing changed, just return
return
self.offload = offload
if offload:
self.offload_manager = CPUOffloadManager(
self.transformer_blocks,
use_pin_memory=kwargs.get("use_pin_memory", True),
on_gpu_modules=[
self.img_in,
self.txt_in,
self.txt_norm,
self.time_text_embed,
self.norm_out,
self.proj_out,
],
num_blocks_on_gpu=kwargs.get("num_blocks_on_gpu", 1),
)
else:
self.offload_manager = None
gc.collect()
torch.cuda.empty_cache()
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
encoder_hidden_states_mask: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
txt_seq_lens: Optional[List[int]] = None,
guidance: torch.Tensor = None, # TODO: this should probably be removed
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
device = hidden_states.device
if self.offload:
self.offload_manager.set_device(device)
hidden_states = self.img_in(hidden_states)
timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
if self.offload:
self.offload_manager.initialize()
compute_stream = self.offload_manager.compute_stream
else:
compute_stream = torch.cuda.current_stream()
for block_idx, block in enumerate(self.transformer_blocks):
with torch.cuda.stream(compute_stream):
if self.offload:
block = self.offload_manager.get_block(block_idx)
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
)
if self.offload:
self.offload_manager.step()
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
torch.cuda.empty_cache()
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def to(self, *args, **kwargs):
"""
Overwrite the default .to() method.
If self.offload is True, avoid moving the model to GPU.
"""
device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
dtype_present_in_args = "dtype" in kwargs
# Try converting arguments to torch.device in case they are passed as strings
for arg in args:
if not isinstance(arg, str):
continue
try:
torch.device(arg)
device_arg_or_kwarg_present = True
except RuntimeError:
pass
if not dtype_present_in_args:
for arg in args:
if isinstance(arg, torch.dtype):
dtype_present_in_args = True
break
if dtype_present_in_args and self._is_initialized:
raise ValueError(
"Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
"use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
)
if self.offload:
if device_arg_or_kwarg_present:
warn("Skipping moving the model to GPU as offload is enabled", UserWarning)
return self
return super(type(self), self).to(*args, **kwargs)
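A brief, hedged note on the control flow added above (my reading of `set_offload` and the `to` override, not additional code from the commit): enabling offload builds a `CPUOffloadManager` over `transformer_blocks`, later device moves are skipped with a warning, and disabling offload releases the manager and its buffers.

# Hedged sketch; assumes `transformer` was constructed as in the example scripts above.
transformer.set_offload(True, num_blocks_on_gpu=2)  # builds a CPUOffloadManager over transformer_blocks
transformer.to("cuda")   # skipped with "Skipping moving the model to GPU as offload is enabled"
transformer.set_offload(False)  # drops the manager, then gc.collect() and torch.cuda.empty_cache()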
@@ -6,14 +6,13 @@ import json
import logging
import os
from pathlib import Path
-from typing import Any, Optional
import torch
from diffusers import __version__
from huggingface_hub import constants, hf_hub_download
from torch import nn
-from nunchaku.utils import ceil_divide, load_state_dict_in_safetensors
+from ...utils import load_state_dict_in_safetensors
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
@@ -146,37 +145,3 @@ class NunchakuModelLoaderMixin:
with torch.device("meta"):
transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
return transformer, unquantized_part_path, transformer_block_path
def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor | None:
"""
Pad a tensor along a given dimension to the next multiple of a specified value.
Parameters
----------
tensor : torch.Tensor or None
Input tensor. If None, returns None.
multiples : int
Pad to this multiple. If <= 1, no padding is applied.
dim : int
Dimension along which to pad.
fill : Any, optional
Value to use for padding (default: 0).
Returns
-------
torch.Tensor or None
The padded tensor, or None if input was None.
"""
if multiples <= 1:
return tensor
if tensor is None:
return None
shape = list(tensor.shape)
if shape[dim] % multiples == 0:
return tensor
shape[dim] = ceil_divide(shape[dim], multiples) * multiples
result = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
result.fill_(fill)
result[[slice(0, extent) for extent in tensor.shape]] = tensor
return result
import copy
import torch
from torch import nn
from ..utils import copy_params_into
def fuse_linears(linears: list[nn.Linear]) -> nn.Linear:
assert len(linears) > 0
@@ -16,3 +21,118 @@ def fuse_linears(linears: list[nn.Linear]) -> nn.Linear:
dtype=linears[0].weight.dtype,
device=linears[0].weight.device,
)
class CPUOffloadManager:
"""Generic manager for per-transformer-block CPU offloading with async memory operations.
This class can be used with any transformer model that has a list of transformer blocks.
It provides memory-efficient processing by keeping only the current block on GPU.
"""
def __init__(
self,
blocks: list[nn.Module],
device: str | torch.device = torch.device("cuda"),
use_pin_memory: bool = True,
on_gpu_modules: list[nn.Module] = [],
num_blocks_on_gpu: int = 1,
empty_cache_freq: int = 0,
):
self.blocks = blocks
self.use_pin_memory = use_pin_memory
self.on_gpu_modules = on_gpu_modules
self.num_blocks_on_gpu = num_blocks_on_gpu
assert self.num_blocks_on_gpu > 0
# Two streams: one for compute, one for memory operations, will be initialized in set_device
self.compute_stream = None
self.memory_stream = None
self.compute_done = torch.cuda.Event(blocking=False)
self.memory_done = torch.cuda.Event(blocking=False)
self.buffer_blocks = [copy.deepcopy(blocks[0]), copy.deepcopy(blocks[0])]
self.device = None
self.set_device(device)
self.current_block_idx = 0
self.forward_counter = 0
self.empty_cache_freq = empty_cache_freq
def set_device(self, device: torch.device | str, force: bool = False):
if isinstance(device, str):
device = torch.device(device)
assert device.type == "cuda"
if self.device == device and not force:
return
self.device = device
self.compute_stream = torch.cuda.Stream(device=device)
self.memory_stream = torch.cuda.Stream(device=device)
for block in self.buffer_blocks:
block.to(device)
for module in self.on_gpu_modules:
module.to(device)
for i, block in enumerate(self.blocks):
if i < self.num_blocks_on_gpu:
block.to(device)
else:
block.to("cpu")
if self.use_pin_memory:
for p in block.parameters(recurse=True):
p.data = p.data.pin_memory()
for b in block.buffers(recurse=True):
b.data = b.data.pin_memory()
def load_block(self, block_idx: int, non_blocking: bool = True):
"""Move a transformer block to GPU."""
# if the block is already on GPU, don't load it to the buffer
if block_idx < self.num_blocks_on_gpu:
return
# past the last block there is nothing to preload: the first num_blocks_on_gpu blocks always stay resident on GPU
if block_idx >= len(self.blocks):
return
block = self.blocks[block_idx]
copy_params_into(block, self.buffer_blocks[block_idx % 2], non_blocking=non_blocking)
def step(self, next_stream: torch.cuda.Stream | None = None):
"""Move to the next block, triggering preload operations."""
next_compute_done = torch.cuda.Event()
next_compute_done.record(self.compute_stream)
with torch.cuda.stream(self.memory_stream):
self.memory_stream.wait_event(self.compute_done)
self.load_block(self.current_block_idx + 1)  # preload the next block; past the last block this is a no-op (the first blocks stay on GPU)
next_memory_done = torch.cuda.Event()
next_memory_done.record(self.memory_stream)
self.memory_done = next_memory_done
self.compute_done = next_compute_done
self.current_block_idx += 1
if self.current_block_idx < len(self.blocks):
# get ready for the next compute
self.compute_stream.wait_event(self.memory_done)
else:
# ready to finish
if next_stream is None:
torch.cuda.current_stream().wait_event(self.compute_done)
else:
next_stream.wait_event(self.compute_done)
self.current_block_idx = 0
self.forward_counter += 1
if self.empty_cache_freq > 0 and self.forward_counter % self.empty_cache_freq == 0:
torch.cuda.empty_cache()
def get_block(self, block_idx: int | None = None) -> nn.Module:
if block_idx is None:
block_idx = self.current_block_idx
if block_idx < self.num_blocks_on_gpu:
return self.blocks[block_idx]
else:
return self.buffer_blocks[block_idx % 2]
def initialize(self, stream: torch.cuda.Stream | None = None):
if stream is None:
stream = torch.cuda.current_stream()
self.compute_done.record(stream)
self.memory_done.record(stream)
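To make the double-buffering above concrete, here is a toy, hedged usage sketch of `CPUOffloadManager` outside any pipeline. The `nn.Linear` blocks stand in for transformer blocks and are not from this commit; the import path is inferred from the `from ..utils import CPUOffloadManager` line in the transformer diff.

import torch
from torch import nn

from nunchaku.models.utils import CPUOffloadManager  # path inferred from this commit's imports

blocks = [nn.Linear(64, 64) for _ in range(8)]  # stand-ins for transformer blocks
manager = CPUOffloadManager(blocks, device="cuda", num_blocks_on_gpu=1)

x = torch.randn(4, 64, device="cuda")
manager.compute_stream.wait_stream(torch.cuda.current_stream())  # make x visible to the compute stream
manager.initialize()  # record the initial compute/memory events on the current stream
for idx in range(len(blocks)):
    with torch.cuda.stream(manager.compute_stream):
        block = manager.get_block(idx)  # a resident block or one of the two GPU buffer blocks
        x = block(x)
    manager.step()  # prefetch the next block on the memory stream and re-sync the streams
torch.cuda.synchronize()
print(x.shape)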
import gc
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from diffusers.pipelines.qwenimage.pipeline_qwenimage import (
QwenImagePipeline,
QwenImagePipelineOutput,
calculate_shift,
retrieve_timesteps,
)
class NunchakuQwenImagePipeline(QwenImagePipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
true_cfg_scale: float = 4.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.0,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = torch.device("cuda")
self.text_encoder.to(device)
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
if do_true_cfg:
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
# Get latents from parent method (returns single tensor)
latents = super().prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# Generate latent_image_ids manually
seq_len = latents.shape[1] if len(latents.shape) > 1 else 0
latent_image_ids = torch.arange(seq_len, device=device, dtype=torch.long)
latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size * num_images_per_prompt, 1)
img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
self.text_encoder.to("cpu")
gc.collect()
torch.cuda.empty_cache()
self.transformer.to(device)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
self.transformer.to("cpu")
gc.collect()
torch.cuda.empty_cache()
self.vae.to(device)
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
self.vae.to("cpu")
gc.collect()
torch.cuda.empty_cache()
if not return_dict:
return (image,)
return QwenImagePipelineOutput(images=image)
@@ -6,10 +6,46 @@ import hashlib
import os
import warnings
from pathlib import Path
from typing import Any
import safetensors
import torch
from huggingface_hub import hf_hub_download
from torch import nn
def pad_tensor(tensor: torch.Tensor | None, multiples: int, dim: int, fill: Any = 0) -> torch.Tensor | None:
"""
Pad a tensor along a given dimension to the next multiple of a specified value.
Parameters
----------
tensor : torch.Tensor or None
Input tensor. If None, returns None.
multiples : int
Pad to this multiple. If <= 1, no padding is applied.
dim : int
Dimension along which to pad.
fill : Any, optional
Value to use for padding (default: 0).
Returns
-------
torch.Tensor or None
The padded tensor, or None if input was None.
"""
if multiples <= 1:
return tensor
if tensor is None:
return None
shape = list(tensor.shape)
if shape[dim] % multiples == 0:
return tensor
shape[dim] = ceil_divide(shape[dim], multiples) * multiples
result = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
result.fill_(fill)
result[[slice(0, extent) for extent in tensor.shape]] = tensor
return result
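A quick illustration of the `pad_tensor` helper now that it lives in `nunchaku.utils` (toy values, not from the commit):

import torch
from nunchaku.utils import pad_tensor  # new location introduced by this commit

x = torch.ones(2, 5)
y = pad_tensor(x, multiples=8, dim=1)
print(y.shape)          # torch.Size([2, 8]): dim 1 is padded up to the next multiple of 8
print(y[:, 5:].sum())   # tensor(0.): the padded region is filled with `fill` (default 0)
assert pad_tensor(x, multiples=1, dim=1) is x  # multiples <= 1 returns the input unchanged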
def sha256sum(filepath: str | os.PathLike[str]) -> str:
@@ -295,3 +331,36 @@ def get_precision_from_quantization_config(quantization_config: dict) -> str:
return "int4"
else:
raise ValueError(f"Unsupported quantization dtype: {quantization_config['weight']['dtype']}")
def copy_params_into(src: nn.Module, dst: nn.Module, non_blocking: bool = True):
"""
Copy all parameters and buffers from a source module to a destination module.
Parameters
----------
src : nn.Module
Source module from which parameters and buffers are copied.
dst : nn.Module
Destination module to which parameters and buffers are copied.
non_blocking : bool, optional
If True, copies are performed asynchronously with respect to the host if possible (default: True).
Notes
-----
- The function assumes that `src` and `dst` have the same structure and number of parameters and buffers.
- All copying is performed under `torch.no_grad()` context to avoid tracking in autograd.
"""
with torch.no_grad():
for ps, pd in zip(src.parameters(), dst.parameters()):
pd.copy_(ps, non_blocking=non_blocking)
for bs, bd in zip(src.buffers(), dst.buffers()):
bd.copy_(bs, non_blocking=non_blocking)
for ms, md in zip(src.modules(), dst.modules()):
# wtscale is a special case which is a float on the CPU
if hasattr(ms, "wtscale"):
assert hasattr(md, "wtscale")
md.wtscale = ms.wtscale
else:
assert not hasattr(md, "wtscale")
@@ -5,4 +5,4 @@ facexlib
onnxruntime
# ip-adapter
timm
-diffusers>=0.33.1
+diffusers==0.35