"""Time Stable Diffusion 2.1-base text-to-image generation with DPM-Solver + DeepCache.

For each of the first ``NUM_SAMPLES`` prompts in the ``Prompt`` column of
``PartiPrompts.tsv`` this script generates one 512x512 image, prints the
wall-clock time of the pipeline call and saves the image as
``<DPM_SAMPLE_PATH>/<index:05>.png``. Needs a CUDA device plus the
``diffusers`` and ``DeepCache`` packages.
"""
import os
import time

import pandas as pd
import torch

# Acceleration experiments that are currently disabled:
# import tomesd
# from xformers.ops import MemoryEfficientAttentionFlashAttentionOp, MemoryEfficientAttentionTritonFwdFlashBwOp
# import torch._dynamo
# torch._dynamo.config.suppress_errors = True
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True

DPM_SAMPLE_PATH = "./DPM-sample"
# MODEL_ID = "/data1/models/stablediffusion/stable-diffusion-xl-base-1.0"
MODEL_ID = "/data1/models/stablediffusion/stable-diffusion-2-1-base"
TEXT_FILE = "PartiPrompts.tsv"
NUM_INFERENCE_STEPS = 20
NUM_SAMPLES = 50  # only the first NUM_SAMPLES prompts are rendered
SEED = 2024
IMAGE_SIZE = 512  # used for both height and width, in pixels


def load_prompts(text_file=TEXT_FILE, limit=NUM_SAMPLES):
    """Return prompts from the ``Prompt`` column of a tab-separated file.

    Args:
        text_file: Path to a TSV file whose header row contains ``Prompt``.
        limit: Maximum number of prompts to return (the first ``limit`` rows,
            in file order); ``None`` returns every row.

    Returns:
        A list of the prompt values, in file order.
    """
    prompts = pd.read_csv(text_file, sep="\t")["Prompt"]
    if limit is not None:
        prompts = prompts.head(limit)
    return prompts.tolist()


def build_pipeline(model_id=MODEL_ID, device="cuda", cache_interval=2, cache_branch_id=0):
    """Load a float16 Stable Diffusion pipeline with DPM-Solver and DeepCache enabled.

    Args:
        model_id: Local path or hub id passed to ``from_pretrained``.
        device: Device the pipeline is moved to.
        cache_interval: Forwarded to ``DeepCacheSDHelper.set_params``.
        cache_branch_id: Forwarded to ``DeepCacheSDHelper.set_params``.

    Returns:
        ``(pipe, helper)`` -- the helper is returned so the caller can keep a
        reference to the DeepCache helper for as long as the pipeline is used.
    """
    # Deferred so that load_prompts() is usable without the GPU-only packages.
    from DeepCache import DeepCacheSDHelper
    from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    # Alternatives used in SDXL experiments:
    # pipe = StableDiffusionXLControlNetPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    # pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)

    helper = DeepCacheSDHelper(pipe=pipe)
    helper.set_params(cache_interval=cache_interval, cache_branch_id=cache_branch_id)
    helper.enable()

    # pipe.unet = torch.compile(pipe.unet, mode="max-autotune-no-cudagraphs")
    # pipe.vae = torch.compile(pipe.vae, mode="max-autotune-no-cudagraphs")
    return pipe, helper


def main():
    """Generate, time and save one image per prompt."""
    os.makedirs(DPM_SAMPLE_PATH, exist_ok=True)
    prompts = load_prompts(TEXT_FILE, NUM_SAMPLES)
    pipe, _helper = build_pipeline(MODEL_ID)  # _helper held for the duration of the run
    # Dedicated generator: the sampling noise depends only on SEED, not on any
    # other code that may draw from torch's global RNG.
    generator = torch.Generator("cpu").manual_seed(SEED)

    print("======================================start DPM ==================================")
    for base_count, prompt in enumerate(prompts):
        start = time.perf_counter()
        image = pipe(
            prompt,
            height=IMAGE_SIZE,
            width=IMAGE_SIZE,
            num_inference_steps=NUM_INFERENCE_STEPS,
            num_images_per_prompt=1,
            generator=generator,
        ).images[0]
        # image = pipe(prompt).images[0]
        print(f"the {base_count} text-to-image use time {time.perf_counter() - start}")
        image.save(os.path.join(DPM_SAMPLE_PATH, f"{base_count:05}.png"))

    print(f"Your samples are ready and waiting for you here\n{DPM_SAMPLE_PATH} \n" f" \nEnjoy.")


if __name__ == "__main__":
    main()