Unverified Commit 25383861 authored by Muyang Li, committed by GitHub

test: use CPU offload to save GPU memory in multi-batch test (#344)

* feat: support FP8 LoRAs (#342)

* feat: support FP8 LoRAs

* fix the int4 expected lpips

* test: use cpu offload to save gpu memory in multi-batch test
parent 37a27712
@@ -14,6 +14,12 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
         tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
     else:
         tensors = {k: v for k, v in input_lora.items()}
+    ### convert the FP8 tensors to BF16
+    for k, v in tensors.items():
+        if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
+            tensors[k] = v.to(torch.bfloat16)
     new_tensors, alphas = FluxLoraLoaderMixin.lora_state_dict(tensors, return_alphas=True)
     if alphas is not None and len(alphas) > 0:
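For context, the hunk above normalizes LoRA tensor dtypes before handing them to FluxLoraLoaderMixin.lora_state_dict, which only deals with the standard float dtypes. A minimal standalone sketch of the same idea follows; the load_fp8_lora_as_bf16 helper name and the direct use of safetensors are illustrative, not code from this repository.

```python
import torch
from safetensors.torch import load_file

# Dtypes the downstream LoRA conversion already understands.
SUPPORTED_DTYPES = (torch.float64, torch.float32, torch.bfloat16, torch.float16)


def load_fp8_lora_as_bf16(path: str) -> dict[str, torch.Tensor]:
    """Illustrative helper: load a LoRA state dict and upcast FP8 tensors to BF16."""
    tensors = load_file(path, device="cpu")
    for k, v in tensors.items():
        if v.dtype not in SUPPORTED_DTYPES:  # e.g. torch.float8_e4m3fn or float8_e5m2
            tensors[k] = v.to(torch.bfloat16)
    return tensors
```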
@@ -34,7 +34,7 @@ from .utils import already_generate, compute_lpips, offload_pipeline
             "fox",
             1234,
             0.7,
-            0.349 if get_precision() == "int4" else 0.349,
+            0.417 if get_precision() == "int4" else 0.349,
         ),
         (
             1024,
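The hunk above only corrects the expected LPIPS threshold for the int4 build; the compute_lpips helper itself is not part of this diff. As a rough sketch of how such a threshold check can be written with the lpips package (an assumption here, not necessarily what .utils uses; the file names are placeholders):

```python
import lpips
import numpy as np
import torch
from PIL import Image


def to_lpips_tensor(path: str) -> torch.Tensor:
    """Load an image as a 1x3xHxW tensor in [-1, 1], the range LPIPS expects."""
    img = np.asarray(Image.open(path).convert("RGB"), dtype=np.float32) / 255.0
    return torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) * 2.0 - 1.0


loss_fn = lpips.LPIPS(net="alex")  # AlexNet-based perceptual distance
dist = loss_fn(to_lpips_tensor("generated.png"), to_lpips_tensor("reference.png")).item()
assert dist < 0.417  # e.g. the int4 threshold updated in this commit
```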
@@ -11,7 +11,7 @@ from .utils import run_test
     "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
     [
         (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.135, 2),
-        (1920, 1080, "flashattn2", False, 0.160 if get_precision() == "int4" else 0.123, 4),
+        (1920, 1080, "flashattn2", True, 0.160 if get_precision() == "int4" else 0.123, 4),
     ],
 )
 def test_flux_schnell(
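Flipping cpu_offload to True for the 1920x1080, batch-size-4 case is what lets this test fit in GPU memory; the repository's offload_pipeline helper is not shown in this diff. A sketch of what CPU offload typically looks like for a diffusers Flux pipeline, with the model ID, prompt, and step count as placeholder values:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
# Keep weights on the CPU and stream submodules to the GPU only while they run.
# Slower than pipe.to("cuda"), but peak VRAM use drops sharply, which matters
# for the 1920x1080 batch-of-4 case this commit adjusts.
pipe.enable_sequential_cpu_offload()

images = pipe(
    prompt="a fox in the snow",
    height=1080,
    width=1920,
    num_images_per_prompt=4,
    num_inference_steps=4,
    guidance_scale=0.0,
).images
```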