Commit 3ecf0d25 authored by muyangli

update

parent e0ffd99d
@@ -2,17 +2,17 @@ import torch
 from diffusers import FluxPipeline
 from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku.utils import get_precision, is_turing

 if __name__ == "__main__":
-    capability = torch.cuda.get_device_capability(0)
-    sm = f"{capability[0]}{capability[1]}"
-    precision = "fp4" if sm == "120" else "int4"
+    precision = get_precision()
+    torch_dtype = torch.float16 if is_turing() else torch.bfloat16
     transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-        f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=True
+        f"mit-han-lab/svdq-{precision}-flux.1-schnell", torch_dtype=torch_dtype, offload=True
     )
     pipeline = FluxPipeline.from_pretrained(
-        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
+        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
     )
     pipeline.enable_sequential_cpu_offload()
     image = pipeline(
...
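This hunk replaces the hand-rolled compute-capability check with the get_precision() and is_turing() helpers from nunchaku.utils, and threads a single torch_dtype through both the transformer and the pipeline. A rough sketch of the logic the removed lines implemented, which the helpers presumably encapsulate (the library's actual internals are an assumption here, not confirmed by this diff):

import torch

def get_precision_sketch() -> str:
    # Mirrors the deleted inline logic: FP4 on SM 120 (Blackwell),
    # INT4 everywhere else. The real nunchaku.utils.get_precision may
    # cover more cases; this is only an illustration.
    major, minor = torch.cuda.get_device_capability(0)
    return "fp4" if f"{major}{minor}" == "120" else "int4"

def is_turing_sketch() -> bool:
    # Assumption: Turing GPUs report compute capability 7.5 and lack
    # fast bfloat16 support, hence the float16 fallback in the diff above.
    major, minor = torch.cuda.get_device_capability(0)
    return (major, minor) == (7, 5)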
@@ -61,9 +61,7 @@ def run_pipeline(dataset, batch_size: int, task: str, pipeline: FluxPipeline, sa
     assert task in ["t2i", "fill"]
     processor = None
-    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
-    for row in tqdm(dataloader):
+    for row in tqdm(dataset.iter(batch_size=batch_size, drop_last_batch=False)):
         filenames = row["filename"]
         prompts = row["prompt"]
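The loop now batches with datasets.Dataset.iter(batch_size=..., drop_last_batch=False) instead of wrapping the dataset in a torch.utils.data.DataLoader. Dataset.iter yields a plain dict of column lists per batch, with no tensor collation, which is all this loop needs. A minimal, self-contained illustration (the toy columns are hypothetical; the script's dataset presumably carries the same "filename" and "prompt" fields):

from datasets import Dataset

ds = Dataset.from_dict({
    "filename": ["0.png", "1.png", "2.png"],
    "prompt": ["a cat", "a dog", "a fox"],
})

# Each batch maps column name -> list of values; the final short batch
# is kept because drop_last_batch=False.
for batch in ds.iter(batch_size=2, drop_last_batch=False):
    print(batch["filename"], batch["prompt"])
# ['0.png', '1.png'] ['a cat', 'a dog']
# ['2.png'] ['a fox']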
@@ -234,6 +232,8 @@ def run_test(
         precision_str += f"-cache{cache_threshold}"
     if i2f_mode is not None:
         precision_str += f"-i2f{i2f_mode}"
+    if batch_size > 1:
+        precision_str += f"-bs{batch_size}"
     save_dir_4bit = os.path.join("test_results", dtype_str, precision_str, model_name, folder_name)
     if not already_generate(save_dir_4bit, max_dataset_size):
...
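The new branch appends a -bs{batch_size} suffix so results generated at different batch sizes land in distinct output directories rather than overwriting each other. A sketch of the resulting path (all concrete values below are hypothetical, chosen only to show the shape):

import os

# Hypothetical values for illustration only.
dtype_str, precision_str = "bf16", "int4"
model_name, folder_name = "flux.1-schnell", "t2i"
batch_size = 4

if batch_size > 1:
    precision_str += f"-bs{batch_size}"  # "int4" -> "int4-bs4"

save_dir_4bit = os.path.join("test_results", dtype_str, precision_str, model_name, folder_name)
print(save_dir_4bit)  # test_results/bf16/int4-bs4/flux.1-schnell/t2i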