Commit e0ffd99d authored by muyangli's avatar muyangli
Browse files

add batch inference

parent 30ba84c5
......@@ -45,7 +45,7 @@ LORA_PATH_MAP = {
}
def run_pipeline(dataset, task: str, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
def run_pipeline(dataset, batch_size: int, task: str, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
os.makedirs(save_dir, exist_ok=True)
pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
......@@ -61,43 +61,61 @@ def run_pipeline(dataset, task: str, pipeline: FluxPipeline, save_dir: str, forw
assert task in ["t2i", "fill"]
processor = None
for row in tqdm(dataset):
filename = row["filename"]
prompt = row["prompt"]
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
for row in tqdm(dataloader):
filenames = row["filename"]
prompts = row["prompt"]
_forward_kwargs = {k: v for k, v in forward_kwargs.items()}
if task == "canny":
assert forward_kwargs.get("height", 1024) == 1024
assert forward_kwargs.get("width", 1024) == 1024
control_image = load_image(row["canny_image_path"])
control_image = processor(
control_image,
low_threshold=50,
high_threshold=200,
detect_resolution=1024,
image_resolution=1024,
)
_forward_kwargs["control_image"] = control_image
control_images = []
for canny_image_path in row["canny_image_path"]:
control_image = load_image(canny_image_path)
control_image = processor(
control_image,
low_threshold=50,
high_threshold=200,
detect_resolution=1024,
image_resolution=1024,
)
control_images.append(control_image)
_forward_kwargs["control_image"] = control_images
elif task == "depth":
control_image = load_image(row["depth_image_path"])
control_image = processor(control_image)[0].convert("RGB")
_forward_kwargs["control_image"] = control_image
control_images = []
for depth_image_path in row["depth_image_path"]:
control_image = load_image(depth_image_path)
control_image = processor(control_image)[0].convert("RGB")
control_images.append(control_image)
_forward_kwargs["control_image"] = control_images
elif task == "fill":
image = load_image(row["image_path"])
mask_image = load_image(row["mask_image_path"])
_forward_kwargs["image"] = image
_forward_kwargs["mask_image"] = mask_image
images, mask_images = [], []
for image_path, mask_image_path in zip(row["image_path"], row["mask_image_path"]):
image = load_image(image_path)
mask_image = load_image(mask_image_path)
images.append(image)
mask_images.append(mask_image)
_forward_kwargs["image"] = images
_forward_kwargs["mask_image"] = mask_images
elif task == "redux":
image = load_image(row["image_path"])
_forward_kwargs.update(processor(image))
images = []
for image_path in row["image_path"]:
image = load_image(image_path)
images.append(image)
_forward_kwargs.update(processor(images))
seed = hash_str_to_int(filename)
seeds = [hash_str_to_int(filename) for filename in filenames]
generators = [torch.Generator().manual_seed(seed) for seed in seeds]
if task == "redux":
image = pipeline(generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
images = pipeline(generator=generators, **_forward_kwargs).images
else:
image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
image.save(os.path.join(save_dir, f"{filename}.png"))
images = pipeline(prompts, generator=generators, **_forward_kwargs).images
for i, image in enumerate(images):
filename = filenames[i]
image.save(os.path.join(save_dir, f"{filename}.png"))
torch.cuda.empty_cache()
......@@ -105,6 +123,7 @@ def run_test(
precision: str = "int4",
model_name: str = "flux.1-schnell",
dataset_name: str = "MJHQ",
batch_size: int = 1,
task: str = "t2i",
dtype: str | torch.dtype = torch.bfloat16, # the full precision dtype
height: int = 1024,
......@@ -185,6 +204,7 @@ def run_test(
pipeline.set_adapters([f"lora_{i}" for i in range(len(lora_names))], lora_strengths)
run_pipeline(
batch_size=batch_size,
dataset=dataset,
task=task,
pipeline=pipeline,
......@@ -255,6 +275,7 @@ def run_test(
else:
pipeline = pipeline.to("cuda")
run_pipeline(
batch_size=batch_size,
dataset=dataset,
task=task,
pipeline=pipeline,
......@@ -272,4 +293,4 @@ def run_test(
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.25
assert lpips < expected_lpips * 1.1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment