Unverified commit 5182f8f8, authored by Muyang Li, committed by GitHub

feat: single-file model loading (#413)

* add a script to merge models

* finished

* try to merge t5

* merge the config into meta files

* rewrite the t5 model loading

* consider the case of subfolder

* merged the qencoder files

* make the linter happy and fix the tests

* pass tests

* add deprecation messages

* add a script to merge models

* schnell script runnable

* update sana

* modify the model paths

* fix the model paths

* style: make the linter happy

* remove the debugging assertion

* chore: fix the qencoder lpips

* fix the lpips
parent 8401d290
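
In short, this commit switches the bundled examples from loading quantized models by Hugging Face repo ID to loading a single merged `.safetensors` file; per the commit notes, deprecation messages were added for the old form. A minimal before/after sketch, using only calls that appear in the diff below:

```python
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

precision = get_precision()  # 'int4' or 'fp4', auto-detected from your GPU

# Old style (now deprecated): load by Hugging Face repo ID
# transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")

# New style: point directly at a single merged .safetensors file
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
```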
@@ -31,7 +31,7 @@ Join our user groups on [**Slack**](https://join.slack.com/t/nunchaku/shared_inv
 <summary>More</summary>
 - **[2025-02-04]** **🚀 4-bit [FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/) is here!** Enjoy a **2-3× speedup** over the original models. Check out the [examples](./examples) for usage. **ComfyUI integration is coming soon!**
-- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) support is here!** Experience a 2-3× speedup compared to the 16-bit model. Check out the [usage example](./examples/sana_1600m_pag.py) and the [deployment guide](app/sana/t2i) for more details. Explore our live demo at [svdquant.mit.edu](https://svdquant.mit.edu)!
+- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) support is here!** Experience a 2-3× speedup compared to the 16-bit model. Check out the [usage example](examples/sana1.6b_pag.py) and the [deployment guide](app/sana/t2i) for more details. Explore our live demo at [svdquant.mit.edu](https://svdquant.mit.edu)!
 - **[2025-01-22]** 🎉 [**SVDQuant**](http://arxiv.org/abs/2411.05007) has been accepted to **ICLR 2025**!
 - **[2024-12-08]** Support [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Please check [mit-han-lab/ComfyUI-nunchaku](https://github.com/mit-han-lab/ComfyUI-nunchaku) for the usage.
 - **[2024-11-07]** 🔥 Our latest **W4A4** Diffusion model quantization work [**SVDQuant**](https://hanlab.mit.edu/projects/svdquant) is publicly released! Check [**DeepCompressor**](https://github.com/mit-han-lab/deepcompressor) for the quantization library.
......
@@ -30,7 +30,7 @@
 <summary>More updates</summary>
 - **[2025-02-04]** **🚀 The 4-bit [FLUX.1-tools](https://blackforestlabs.ai/flux-1-tools/) are released!** They run **2-3×** faster than the original models. The [example code](./examples) has been updated; **ComfyUI support is coming soon!**
-- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) is now supported!** It runs 2-3× faster than the 16-bit model. The [usage example](./examples/sana_1600m_pag.py) and the [deployment guide](app/sana/t2i) are available; try the [online demo](https://svdquant.mit.edu)!
+- **[2025-01-23]** 🚀 **4-bit [SANA](https://nvlabs.github.io/Sana/) is now supported!** It runs 2-3× faster than the 16-bit model. The [usage example](examples/sana1.6b_pag.py) and the [deployment guide](app/sana/t2i) are available; try the [online demo](https://svdquant.mit.edu)!
 - **[2025-01-22]** 🎉 [**SVDQuant**](http://arxiv.org/abs/2411.05007) has been accepted to **ICLR 2025**!
 - **[2024-12-08]** [ComfyUI](https://github.com/comfyanonymous/ComfyUI) is supported; see [mit-han-lab/ComfyUI-nunchaku](https://github.com/mit-han-lab/ComfyUI-nunchaku) for details.
 - **[2024-11-07]** 🔥 Our latest **W4A4** diffusion model quantization work [**SVDQuant**](https://hanlab.mit.edu/projects/svdquant) is open-sourced, together with the quantization library [**DeepCompressor**](https://github.com/mit-han-lab/deepcompressor)!
......
@@ -7,7 +7,9 @@ from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-canny-dev/svdq-{precision}_r32-flux.1-canny-dev.safetensors"
+)
 pipe = FluxControlPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
......
@@ -7,7 +7,9 @@ from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-canny-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-canny-dev/svdq-{precision}_r32-flux.1-canny-dev.safetensors"
+)
 pipe = FluxControlPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Canny-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
......
@@ -7,7 +7,9 @@ from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors"
+)
 pipe = FluxControlPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
......
@@ -7,7 +7,9 @@ from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-depth-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors"
+)
 pipe = FluxControlPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Depth-dev",
......
@@ -6,7 +6,9 @@ from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+)
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
......
@@ -15,7 +15,9 @@ controlnet = FluxMultiControlNetModel([controlnet_union])  # we always recommend
 precision = get_precision()
 need_offload = get_gpu_memory() < 36
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/svdq-{precision}-flux.1-dev", torch_dtype=torch.bfloat16, offload=need_offload
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
+    torch_dtype=torch.bfloat16,
+    offload=need_offload,
 )
 transformer.set_attention_impl("nunchaku-fp16")
......
@@ -7,7 +7,9 @@ from nunchaku.utils import get_precision
 precision = get_precision()
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+)
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......
@@ -5,7 +5,9 @@ from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+)
 transformer.set_attention_impl("nunchaku-fp16")  # set attention implementation to fp16
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......
@@ -5,7 +5,9 @@ from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+)
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
......
@@ -6,7 +6,9 @@ from nunchaku.lora.flux.compose import compose_lora
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+)
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
......
@@ -6,7 +6,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/svdq-{precision}-flux.1-dev", offload=True
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors", offload=True
 )  # set offload to False if you want to disable offloading
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......
@@ -9,7 +9,9 @@ from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
 from nunchaku.utils import get_precision
 precision = get_precision()
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+)
 pipeline = PuLIDFluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
......
@@ -5,8 +5,10 @@ from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
-text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+)
+text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors")
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     text_encoder_2=text_encoder_2,
......
@@ -8,7 +8,9 @@ from nunchaku.caching.teacache import TeaCache
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+)
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
......
@@ -6,7 +6,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/svdq-{precision}-flux.1-dev",
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
     offload=True,
     torch_dtype=torch.float16,  # Turing GPUs only support fp16 precision
 )  # set offload to False if you want to disable offloading
......
@@ -5,7 +5,9 @@ from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+)
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
......
@@ -9,7 +9,9 @@ image = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev
 mask = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/mask.png")
 precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-fill-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-fill-dev/svdq-{precision}_r32-flux.1-fill-dev.safetensors"
+)
 pipe = FluxFillPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
......
@@ -9,7 +9,9 @@ precision = get_precision()
 pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
 ).to("cuda")
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+)
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     text_encoder=None,
......
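
For reference, the pieces above combine into a complete end-to-end pipeline roughly as follows. This is a sketch assembled from the hunks in this diff; the prompt and inference arguments are illustrative and not part of the change:

```python
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.utils import get_precision

precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU

# Single-file loading for both the transformer and the quantized T5 encoder
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Illustrative inference call; tune the arguments for your own use case.
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50).images[0]
image.save("flux.1-dev.png")
```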