DPM_diffusers.py
# import tomesd  # optional: Token Merging for Stable Diffusion (disabled)
import torch
import time
import os
import pandas as pd
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionXLControlNetPipeline, DiffusionPipeline
#from xformers.ops import MemoryEfficientAttentionFlashAttentionOp, MemoryEfficientAttentionTritonFwdFlashBwOp

# import torch._dynamo
# torch._dynamo.config.suppress_errors = True

# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True
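# All generated samples are written to this directory.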
dpm_sample_path = "./DPM-sample"
os.makedirs(dpm_sample_path, exist_ok=True)

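# Fix the RNG seed for reproducible generations; model_id points at a local
# Stable Diffusion checkout (SD 2.1-base here, SDXL-base as the commented alternative).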
generator = torch.manual_seed(2024)
# model_id = "/data1/models/stablediffusion/stable-diffusion-xl-base-1.0"
model_id = "/data1/models/stablediffusion/stable-diffusion-2-1-base"
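# PartiPrompts.tsv: tab-separated benchmark prompts, read via the "Prompt" column.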
text_file = "PartiPrompts.tsv"

df = pd.read_csv(text_file, sep='\t')
prompts = df['Prompt']
num_inference_steps = 20
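# Load the pipeline in fp16 and swap in the DPM-Solver++ multistep scheduler,
# which produces good samples in roughly 20 steps.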
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# pipe = StableDiffusionXLControlNetPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

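# DeepCache speeds up sampling by caching high-level U-Net features and reusing
# them on intermediate steps: a full U-Net pass runs only every `cache_interval`
# steps, and `cache_branch_id` selects the skip branch the cache is taken from
# (0 = shallowest).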
from DeepCache import DeepCacheSDHelper
helper = DeepCacheSDHelper(pipe=pipe)
helper.set_params(
    cache_interval=2,
    cache_branch_id=0,
)
helper.enable()
###################################
# pipe.unet = torch.compile(pipe.unet, mode="max-autotune-no-cudagraphs")
# pipe.vae = torch.compile(pipe.vae, mode="max-autotune-no-cudagraphs")
###################################

base_count = 0
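# Generate one 512x512 image per prompt, timing each call; stop after 50 prompts.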
print("======================================start DPM ==================================")
for prompt in prompts:
    start = time.time()
    image = pipe(prompt, height=512, width=512, num_inference_steps=num_inference_steps, num_images_per_prompt=1, generator=generator).images[0]
    # image = pipe(prompt).images[0]
    print(f"the {base_count} text-to-image use time {time.time()-start}")
    image.save(os.path.join(dpm_sample_path, f"{base_count:05}.png"))
    base_count += 1
    if base_count == 50:
        break
print(f"Your samples are ready and waiting for you here\n{dpm_sample_path} \n"
          f" \nEnjoy.")