Commit 513d7ed8 authored by muyangli's avatar muyangli
Browse files

[minor] support loading local models

parent 73edbaaa
......@@ -20,20 +20,21 @@ pip install -r requirements.txt
* Place the SVDQuant workflow configurations from [`workflows`](./workflows) into `user/default/workflows`.
* For example
```shell
# Clone repositories (skip if already cloned)
git clone https://github.com/comfyanonymous/ComfyUI.git
git clone https://github.com/mit-han-lab/nunchaku.git
cd ComfyUI
# Copy workflow configurations
mkdir -p user/default/workflows
cp ../nunchaku/comfyui/workflows/* user/default/workflows/
# Add SVDQuant nodes
cd custom_nodes
ln -s ../../nunchaku/comfyui svdquant
```
```shell
# Clone repositories (skip if already cloned)
git clone https://github.com/comfyanonymous/ComfyUI.git
git clone https://github.com/mit-han-lab/nunchaku.git
cd ComfyUI
# Copy workflow configurations
mkdir -p user/default/workflows
cp ../nunchaku/comfyui/workflows/* user/default/workflows/
# Add SVDQuant nodes
cd custom_nodes
ln -s ../../nunchaku/comfyui svdquant
```
2. **Download Required Models**: Follow [this tutorial](https://comfyanonymous.github.io/ComfyUI_examples/flux/) and download the required models into the appropriate directories using the commands below:
......@@ -55,7 +56,14 @@ pip install -r requirements.txt
* **SVDQuant Flux DiT Loader**: A node for loading the FLUX diffusion model.
* `model_path`: Specifies the model location. It can be set to either `mit-han-lab/svdq-int-flux.1-schnell` or `mit-han-lab/svdq-int-flux.1-dev`. The model will automatically download from our Hugging Face repository.
* `model_path`: Specifies the model location. If set to `mit-han-lab/svdq-int4-flux.1-schnell` or `mit-han-lab/svdq-int4-flux.1-dev`, the model will be automatically downloaded from our Hugging Face repository. Alternatively, you can manually download the model directory by running the following command:
```shell
huggingface-cli download mit-han-lab/svdq-int4-flux.1-dev --local-dir models/diffusion_models/svdq-int4-flux.1-dev
```
After downloading, specify the corresponding folder name as the `model_path`.
* `device_id`: Indicates the GPU ID for running the model.
* **SVDQuant LoRA Loader**: A node for loading LoRA modules for SVDQuant diffusion models.
......@@ -70,9 +78,9 @@ pip install -r requirements.txt
- `text_encoder1`: `t5xxl_fp16.safetensors`
- `text_encoder2`: `clip_l.safetensors`
* **`t5_min_length`**: Sets the minimum sequence length for T5 text embeddings. The default in `DualCLIPLoader` is hardcoded to 256, but for better image quality in SVDQuant, use 512 here.
* `t5_min_length`: Sets the minimum sequence length for T5 text embeddings. The default in `DualCLIPLoader` is hardcoded to 256, but for better image quality in SVDQuant, use 512 here.
* **`t5_precision`**: Specifies the precision of the T5 text encoder. Choose `INT4` to use the INT4 text encoder, which reduces GPU memory usage by approximately 15GB. Please install [`deepcompressor`](https://github.com/mit-han-lab/deepcompressor) when using it:
* `t5_precision`: Specifies the precision of the T5 text encoder. Choose `INT4` to use the INT4 text encoder, which reduces GPU memory usage by approximately 15GB. Please install [`deepcompressor`](https://github.com/mit-han-lab/deepcompressor) when using it:
```shell
git clone https://github.com/mit-han-lab/deepcompressor
......@@ -81,4 +89,11 @@ pip install -r requirements.txt
poetry install
```
* `int4_model`: Specifies the INT4 model location. This option is only used when `t5_precision` is set to `INT4`. By default, the path is `mit-han-lab/svdq-flux.1-t5`, and the model will automatically download from our Hugging Face repository. Alternatively, you can manually download the model directory by running the following command:
```shell
huggingface-cli download mit-han-lab/svdq-flux.1-t5 --local-dir models/text_encoders/svdq-flux.1-t5
```
After downloading, specify the corresponding folder name as the `int4_model`.
......@@ -62,6 +62,16 @@ class SVDQuantFluxDiTLoader:
@classmethod
def INPUT_TYPES(s):
model_paths = ["mit-han-lab/svdq-int4-flux.1-schnell", "mit-han-lab/svdq-int4-flux.1-dev"]
prefix = "models/diffusion_models"
local_folders = os.listdir(prefix)
local_folders = sorted(
[
folder
for folder in local_folders
if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
]
)
model_paths.extend(local_folders)
ngpus = len(GPUtil.getGPUs())
return {
"required": {
......@@ -80,6 +90,11 @@ class SVDQuantFluxDiTLoader:
def load_model(self, model_path: str, device_id: int, **kwargs) -> tuple[FluxTransformer2DModel]:
device = f"cuda:{device_id}"
prefix = "models/diffusion_models"
if os.path.exists(os.path.join(prefix, model_path)):
model_path = os.path.join(prefix, model_path)
else:
model_path = model_path
transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path).to(device)
dit_config = {
"image_model": "flux",
......@@ -135,6 +150,17 @@ def svdquant_t5_forward(
class SVDQuantTextEncoderLoader:
@classmethod
def INPUT_TYPES(s):
model_paths = ["mit-han-lab/svdq-flux.1-t5"]
prefix = "models/text_encoders"
local_folders = os.listdir(prefix)
local_folders = sorted(
[
folder
for folder in local_folders
if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
]
)
model_paths.extend(local_folders)
return {
"required": {
"model_type": (["flux"],),
......@@ -145,6 +171,7 @@ class SVDQuantTextEncoderLoader:
{"default": 512, "min": 256, "max": 1024, "step": 128, "display": "number", "lazy": True},
),
"t5_precision": (["BF16", "INT4"],),
"int4_model": (model_paths, {"tooltip": "The name of the INT4 model."}),
}
}
......@@ -156,7 +183,13 @@ class SVDQuantTextEncoderLoader:
TITLE = "SVDQuant Text Encoder Loader"
def load_text_encoder(
self, model_type: str, text_encoder1: str, text_encoder2: str, t5_min_length: int, t5_precision: str
self,
model_type: str,
text_encoder1: str,
text_encoder2: str,
t5_min_length: int,
t5_precision: str,
int4_model: str,
):
text_encoder_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder1)
text_encoder_path2 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder2)
......@@ -181,7 +214,13 @@ class SVDQuantTextEncoderLoader:
param = next(transformer.parameters())
dtype = param.dtype
device = param.device
transformer = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
prefix = "models/text_encoders"
if os.path.exists(os.path.join(prefix, int4_model)):
model_path = os.path.join(prefix, int4_model)
else:
model_path = int4_model
transformer = NunchakuT5EncoderModel.from_pretrained(model_path)
transformer.forward = types.MethodType(svdquant_t5_forward, transformer)
clip.cond_stage_model.t5xxl.transformer = (
transformer.to(device=device, dtype=dtype) if device.type == "cuda" else transformer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment