Unverified Commit 25383861 authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

test: use CPU offload to save GPU memory in multi-batch test (#344)

* feat: support FP8 LoRAs (#342)

* feat: support FP8 LoRAs

* fix the int4 expected lpips

* test: use cpu offload to save gpu memory in multi-batch test
parent 37a27712
...@@ -14,6 +14,12 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N ...@@ -14,6 +14,12 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
tensors = load_state_dict_in_safetensors(input_lora, device="cpu") tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else: else:
tensors = {k: v for k, v in input_lora.items()} tensors = {k: v for k, v in input_lora.items()}
### convert the FP8 tensors to BF16
for k, v in tensors.items():
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
tensors[k] = v.to(torch.bfloat16)
new_tensors, alphas = FluxLoraLoaderMixin.lora_state_dict(tensors, return_alphas=True) new_tensors, alphas = FluxLoraLoaderMixin.lora_state_dict(tensors, return_alphas=True)
if alphas is not None and len(alphas) > 0: if alphas is not None and len(alphas) > 0:
......
...@@ -34,7 +34,7 @@ from .utils import already_generate, compute_lpips, offload_pipeline ...@@ -34,7 +34,7 @@ from .utils import already_generate, compute_lpips, offload_pipeline
"fox", "fox",
1234, 1234,
0.7, 0.7,
0.349 if get_precision() == "int4" else 0.349, 0.417 if get_precision() == "int4" else 0.349,
), ),
( (
1024, 1024,
......
...@@ -11,7 +11,7 @@ from .utils import run_test ...@@ -11,7 +11,7 @@ from .utils import run_test
"height,width,attention_impl,cpu_offload,expected_lpips,batch_size", "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
[ [
(1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.135, 2), (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.135, 2),
(1920, 1080, "flashattn2", False, 0.160 if get_precision() == "int4" else 0.123, 4), (1920, 1080, "flashattn2", True, 0.160 if get_precision() == "int4" else 0.123, 4),
], ],
) )
def test_flux_schnell( def test_flux_schnell(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment