Unverified Commit b4d3f50b authored by Andrea Ferretti, committed by GitHub

feat: expose norm1 layer to support TeaCache (#234)



* feat: expose norm1 layer to support TeaCache

* feat: add TeaCache example

* feat: add idx as optional parameter

* chore: rename function

* refactor: move TeaCache decorator into example script

* test: add a test for the combination of Nunchaku with TeaCache

* fix: make tests run on low memory hardware

* fix: ensure that memory is correctly released between tests

* fix: avoid moving pipeline to device prematurely

* gpu memory does not release

* need to figure out a way to get compatible with offloading

* wrap up the teacache

---------
Co-authored-by: muyangli <lmxyy1999@foxmail.com>
parent 17ddd2d9
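In short, TeaCache probes the first transformer block's `norm1` (AdaLayerNormZero) output at every denoising step and skips the full stack of blocks when the modulated input has barely changed, reusing the residual from the last full pass instead. A minimal sketch of that decision, with a hypothetical helper name and simplified accumulation (the actual logic lives in the `nunchaku.caching.teacache` module below):

from typing import Optional, Tuple

import torch


def should_recompute(
    modulated: torch.Tensor,
    previous: Optional[torch.Tensor],
    accumulated: float,
    threshold: float,
) -> Tuple[bool, float]:
    # First step, or nothing cached yet: always run the full model.
    if previous is None:
        return True, 0.0
    # Relative L1 change of the first block's modulated input since the previous step.
    rel_l1 = ((modulated - previous).abs().mean() / previous.abs().mean()).item()
    accumulated += rel_l1  # the real code first rescales rel_l1 with a polynomial fitted for FLUX
    if accumulated < threshold:
        return False, accumulated  # skip all transformer blocks and reuse the cached residual
    return True, 0.0  # run the blocks and reset the accumulated estimate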
import time
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision
precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
start_time = time.time()
with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=True):
    image = pipeline(
        "A cat holding a sign that says hello world",
        num_inference_steps=50,
        guidance_scale=3.5,
        height=1024,
        width=1024,
        generator=torch.Generator(device="cuda").manual_seed(0),
    ).images[0]
end_time = time.time()
print(f"Time taken: {(end_time - start_time)} seconds")
image.save(f"flux.1-dev-{precision}-tc.png")
from types import MethodType
from typing import Any, Callable, Optional, Union
import numpy as np
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import logging
from diffusers.utils.constants import USE_PEFT_BACKEND
from diffusers.utils.import_utils import is_torch_version
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers
from ..models.transformers import NunchakuFluxTransformer2dModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_steps: int = 0) -> Callable:
def teacache_forward(
self: Union[FluxTransformer2DModel, NunchakuFluxTransformer2dModel],
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
pooled_projections: torch.Tensor,
timestep: torch.LongTensor,
img_ids: torch.Tensor,
txt_ids: torch.Tensor,
guidance: torch.Tensor,
joint_attention_kwargs: Optional[dict[str, Any]] = None,
controlnet_block_samples: Optional[torch.Tensor] = None,
controlnet_single_block_samples: Optional[torch.Tensor] = None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000 # type: ignore
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) # type: ignore
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
inp = hidden_states.clone()
temb_ = temb.clone()
modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, emb=temb_) # type: ignore
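# TeaCache heuristic: the first block's AdaLayerNormZero ("norm1") output is a cheap proxy for how much
# this step will change the hidden states. Its relative L1 change from the previous step is rescaled by a
# polynomial fitted for FLUX (coefficients from the TeaCache project) and accumulated; while the accumulated
# estimate stays below `rel_l1_thresh`, the transformer blocks are skipped and the cached residual is reused.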
if self.cnt == 0 or self.cnt == num_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [
4.98651651e02,
-2.83781631e02,
5.58554382e01,
-3.82021401e00,
2.64230861e-01,
]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(
(
(modulated_inp - self.previous_modulated_input).abs().mean()
/ self.previous_modulated_input.abs().mean()
)
.cpu()
.item()
)
if self.accumulated_rel_l1_distance < rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.cnt += 1
if self.cnt == num_steps:
self.cnt = 0
ckpt_kwargs: dict[str, Any]
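# Caching only kicks in after the first `skip_steps` denoising steps; earlier steps always run every block.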
if self.cnt > skip_steps:
if not should_calc:
hidden_states += self.previous_residual
else:
ori_hidden_states = hidden_states.clone()
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None): # type: ignore
def custom_forward(*inputs): # type: ignore
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None): # type: ignore
def custom_forward(*inputs): # type: ignore
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
self.previous_residual = hidden_states - ori_hidden_states
else:
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None): # type: ignore
def custom_forward(*inputs): # type: ignore
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None): # type: ignore
def custom_forward(*inputs): # type: ignore
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output: torch.FloatTensor = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return output
return Transformer2DModelOutput(sample=output)
return teacache_forward
# A context manager that adds TeaCache support to a block of code.
# While the context is active, the model's forward method is replaced with a
# TeaCache-enabled variant; the original forward is restored on exit.
class TeaCache:
def __init__(
self,
model: Union[FluxTransformer2DModel, NunchakuFluxTransformer2dModel],
num_steps: int = 50,
rel_l1_thresh: float = 0.6,
skip_steps: int = 0,
enabled: bool = True,
) -> None:
self.model = model
self.num_steps = num_steps
self.rel_l1_thresh = rel_l1_thresh
self.skip_steps = skip_steps
self.enabled = enabled
self.previous_model_forward = self.model.forward
def __enter__(self) -> "TeaCache":
if self.enabled:
# self.model.__class__.forward = make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps) # type: ignore
self.model.forward = MethodType(
make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps), self.model
)
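# Per-run caching state, attached to the model and removed again in __exit__.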
self.model.cnt = 0
self.model.accumulated_rel_l1_distance = 0
self.model.previous_modulated_input = None
self.model.previous_residual = None
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.enabled:
self.model.forward = self.previous_model_forward
del self.model.cnt
del self.model.accumulated_rel_l1_distance
del self.model.previous_modulated_input
del self.model.previous_residual
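Because `enabled=False` turns the context manager into a no-op, the same code path can be timed with and without caching; a minimal sketch, reusing the `pipeline` and `transformer` objects from the example script above:

import time

# Time the same generation with TeaCache disabled and enabled (illustrative only).
for use_teacache in (False, True):
    start = time.time()
    with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=use_teacache):
        image = pipeline(
            "A cat holding a sign that says hello world",
            num_inference_steps=50,
            generator=torch.Generator(device="cuda").manual_seed(0),
        ).images[0]
    print(f"TeaCache {'on' if use_teacache else 'off'}: {time.time() - start:.1f}s")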
@@ -133,6 +133,22 @@ public:
return hidden_states;
}
// Expose the norm1 forward method of the transformer blocks.
// TeaCache uses this to obtain the modulated norm1 output without running a full block.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> norm_one_forward(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor temb
) {
AdaLayerNormZero::Output result = net->transformer_blocks.at(idx)->norm1.forward(from_torch(hidden_states), from_torch(temb));
return {
to_torch(result.x),
to_torch(result.gate_msa),
to_torch(result.shift_mlp),
to_torch(result.scale_mlp),
to_torch(result.gate_mlp)
};
}
// must be called after loading lora
// skip specific ranks in W4A4 layers
......
@@ -48,6 +48,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("controlnet_single_block_samples") = py::none()
)
.def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
.def("norm_one_forward", &QuantizedFluxModel::norm_one_forward)
.def("startDebug", &QuantizedFluxModel::startDebug)
.def("stopDebug", &QuantizedFluxModel::stopDebug)
.def("getDebugResults", &QuantizedFluxModel::getDebugResults)
......
@@ -183,6 +183,14 @@ class NunchakuFluxTransformerBlocks(nn.Module):
def __del__(self):
self.m.reset()
def norm1(
self,
hidden_states: torch.Tensor,
emb: torch.Tensor,
idx: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return self.m.norm_one_forward(idx, hidden_states, emb)
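This wrapper is what the TeaCache forward above calls via `self.transformer_blocks[0].norm1(...)`: it returns the modulated hidden states plus the four modulation tensors without running attention or the MLP. A minimal probe, assuming a loaded `NunchakuFluxTransformer2dModel` named `transformer` and suitably shaped `hidden_states`/`temb` tensors:

# Probe only the first block's AdaLayerNormZero output (no attention or MLP is executed).
modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp = transformer.transformer_blocks[0].norm1(
    hidden_states, emb=temb, idx=0
)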
## copied from diffusers 0.30.3
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
......
@@ -117,11 +117,11 @@ public:
const int dim_head;
const int num_heads;
const bool context_pre_only;
+ AdaLayerNormZero norm1;
AttentionImpl attnImpl = AttentionImpl::FlashAttention2;
private:
- AdaLayerNormZero norm1;
AdaLayerNormZero norm1_context;
GEMM qkv_proj;
GEMM qkv_proj_context;
......
@@ -8,7 +8,7 @@ from .utils import run_test
@pytest.mark.parametrize(
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[
- (0.12, 1024, 1024, 30, None, 1, 0.212 if get_precision() == "int4" else 0.144),
+ (0.12, 1024, 1024, 30, None, 1, 0.212 if get_precision() == "int4" else 0.161),
],
)
def test_flux_dev_cache(
......
@@ -8,11 +8,11 @@ from .utils import run_test
@pytest.mark.parametrize(
"use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[
- (True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.144),
+ (True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.165),
(True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.144),
],
)
- def test_flux_dev_cache(
+ def test_flux_dev_double_fb_cache(
use_double_fb_cache: bool,
residual_diff_threshold_multi: float,
residual_diff_threshold_single: float,
......
import gc
import os
import subprocess
import pytest
import torch
EXAMPLES_DIR = "./examples"
@@ -10,6 +12,8 @@ example_scripts = [f for f in os.listdir(EXAMPLES_DIR) if f.endswith(".py") and
@pytest.mark.parametrize("script_name", example_scripts)
def test_example_script_runs(script_name):
gc.collect()
torch.cuda.empty_cache()
script_path = os.path.join(EXAMPLES_DIR, script_name)
result = subprocess.run(["python", script_path], capture_output=True, text=True)
print(f"Running {script_path} -> Return code: {result.returncode}")
......
@@ -14,7 +14,7 @@ from .utils import run_test
(2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
],
)
- def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
+ def test_flux_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
run_test(
precision=get_precision(),
height=height,
......
import gc
import os
import pytest
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision, is_turing
from .utils import already_generate, compute_lpips, offload_pipeline
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
"height,width,num_inference_steps,prompt,name,seed,threshold,expected_lpips",
[
(
1024,
1024,
30,
"A cat holding a sign that says hello world",
"cat",
0,
0.6,
0.363 if get_precision() == "int4" else 0.363,
),
(
512,
2048,
25,
"The brown fox jumps over the lazy dog",
"fox",
1234,
0.7,
0.349 if get_precision() == "int4" else 0.349,
),
(
1024,
768,
50,
"A scene from the Titanic movie featuring the Muppets",
"muppets",
42,
0.3,
0.360 if get_precision() == "int4" else 0.495,
),
(
1024,
768,
50,
"A crystal ball showing a waterfall",
"waterfall",
23,
0.6,
0.226 if get_precision() == "int4" else 0.226,
),
],
)
def test_flux_teacache(
height: int,
width: int,
num_inference_steps: int,
prompt: str,
name: str,
seed: int,
threshold: float,
expected_lpips: float,
):
gc.collect()
torch.cuda.empty_cache()
device = torch.device("cuda")
precision = get_precision()
ref_root = os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref"))
results_dir_16_bit = os.path.join(ref_root, "bf16", "flux.1-dev", "teacache", name)
results_dir_4_bit = os.path.join("test_results", precision, "flux.1-dev", "teacache", name)
os.makedirs(results_dir_16_bit, exist_ok=True)
os.makedirs(results_dir_4_bit, exist_ok=True)
# First, generate results with the 16-bit model
if not already_generate(results_dir_16_bit, 1):
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Possibly offload the model to CPU when GPU memory is scarce
pipeline = offload_pipeline(pipeline)
result = pipeline(
prompt=prompt,
num_inference_steps=num_inference_steps,
height=height,
width=width,
generator=torch.Generator(device=device).manual_seed(seed),
).images[0]
result.save(os.path.join(results_dir_16_bit, f"{name}_{seed}.png"))
# Clean up the 16-bit model
del pipeline.transformer
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.vae
del pipeline
del result
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info() # bytes
print(f"After 16-bit generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
# Then, generate results with the 4-bit model
if not already_generate(results_dir_4_bit, 1):
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
with torch.inference_mode():
with TeaCache(
model=pipeline.transformer, num_steps=num_inference_steps, rel_l1_thresh=threshold, enabled=True
):
result = pipeline(
prompt=prompt,
num_inference_steps=num_inference_steps,
height=height,
width=width,
generator=torch.Generator(device=device).manual_seed(seed),
).images[0]
result.save(os.path.join(results_dir_4_bit, f"{name}_{seed}.png"))
# Clean up the 4-bit model
del pipeline
del transformer
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info() # bytes
print(f"After 4-bit generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
lpips = compute_lpips(results_dir_16_bit, results_dir_4_bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.1
@@ -9,11 +9,11 @@ from .utils import run_test
@pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
[
(1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.118, 2),
(1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.135, 2),
(1920, 1080, "flashattn2", False, 0.160 if get_precision() == "int4" else 0.123, 4),
],
)
- def test_int4_schnell(
+ def test_flux_schnell(
height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float, batch_size: int
):
run_test(
......
@@ -200,8 +198,6 @@ def run_test(
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline_init_kwargs = {"text_encoder": None, "text_encoder2": None} if task == "redux" else {}
pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
- gpu_properties = torch.cuda.get_device_properties(0)
- gpu_memory = gpu_properties.total_memory / (1024**2)
if len(lora_names) > 0:
for i, (lora_name, lora_strength) in enumerate(zip(lora_names, lora_strengths)):
@@ -211,27 +209,7 @@
)
pipeline.set_adapters([f"lora_{i}" for i in range(len(lora_names))], lora_strengths)
- if gpu_memory > 36 * 1024:
- pipeline = pipeline.to("cuda")
- elif gpu_memory < 26 * 1024:
- pipeline.transformer.enable_group_offload(
- onload_device=torch.device("cuda"),
- offload_device=torch.device("cpu"),
- offload_type="leaf_level",
- use_stream=True,
- )
- if pipeline.text_encoder is not None:
- pipeline.text_encoder.to("cuda")
- if pipeline.text_encoder_2 is not None:
- apply_group_offloading(
- pipeline.text_encoder_2,
- onload_device=torch.device("cuda"),
- offload_type="block_level",
- num_blocks_per_group=2,
- )
- pipeline.vae.to("cuda")
- else:
- pipeline.enable_model_cpu_offload()
+ pipeline = offload_pipeline(pipeline)
run_pipeline(
batch_size=batch_size,
@@ -343,3 +321,34 @@ def run_test(
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.1
def offload_pipeline(pipeline: FluxPipeline) -> FluxPipeline:
gpu_properties = torch.cuda.get_device_properties(0)
gpu_memory = gpu_properties.total_memory / (1024**2)
device = torch.device("cuda")
cpu = torch.device("cpu")
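# Heuristic tiers (rough assumption for FLUX.1-dev-sized pipelines): with more than
# 36 GiB of VRAM the whole BF16 pipeline fits on the GPU; with less than 26 GiB the
# transformer is group-offloaded at leaf level and the T5 text encoder at block level;
# anything in between falls back to diffusers' model-level CPU offload.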
if gpu_memory > 36 * 1024:
pipeline = pipeline.to(device)
elif gpu_memory < 26 * 1024:
pipeline.transformer.enable_group_offload(
onload_device=device,
offload_device=cpu,
offload_type="leaf_level",
use_stream=True,
)
if pipeline.text_encoder is not None:
pipeline.text_encoder.to(device)
if pipeline.text_encoder_2 is not None:
apply_group_offloading(
pipeline.text_encoder_2,
onload_device=device,
offload_type="block_level",
num_blocks_per_group=2,
)
pipeline.vae.to(device)
else:
pipeline.enable_model_cpu_offload()
return pipeline