Commit 3ecf0d25 authored by muyangli
Browse files

update

parent e0ffd99d
......@@ -2,17 +2,17 @@ import torch
from diffusers import FluxPipeline
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing
if __name__ == "__main__":
capability = torch.cuda.get_device_capability(0)
sm = f"{capability[0]}{capability[1]}"
precision = "fp4" if sm == "120" else "int4"
precision = get_precision()
torch_dtype = torch.float16 if is_turing() else torch.bfloat16
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=True
f"mit-han-lab/svdq-{precision}-flux.1-schnell", torch_dtype=torch_dtype, offload=True
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
)
pipeline.enable_sequential_cpu_offload()
image = pipeline(
......
......@@ -61,9 +61,7 @@ def run_pipeline(dataset, batch_size: int, task: str, pipeline: FluxPipeline, sa
assert task in ["t2i", "fill"]
processor = None
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
for row in tqdm(dataloader):
for row in tqdm(dataset.iter(batch_size=batch_size, drop_last_batch=False)):
filenames = row["filename"]
prompts = row["prompt"]
......@@ -234,6 +232,8 @@ def run_test(
precision_str += f"-cache{cache_threshold}"
if i2f_mode is not None:
precision_str += f"-i2f{i2f_mode}"
if batch_size > 1:
precision_str += f"-bs{batch_size}"
save_dir_4bit = os.path.join("test_results", dtype_str, precision_str, model_name, folder_name)
if not already_generate(save_dir_4bit, max_dataset_size):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment