# inference_bf16.py
# Stable Diffusion bf16 inference on CPU with Intel Extension for PyTorch (IPEX).
import argparse

3
import intel_extension_for_pytorch as ipex
4
5
import torch

6
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
7
8


9
10
11
12
# Command-line options.
# ArgumentParser's first positional parameter is `prog` (the program name shown
# in usage lines), so the human-readable summary must go in `description`.
# Help is left enabled so `-h/--help` actually shows it.
parser = argparse.ArgumentParser(description="Stable Diffusion script with intel optimization")
parser.add_argument("--dpm", action="store_true", help="Enable DPMSolver or not")
parser.add_argument("--steps", default=None, type=int, help="Num inference steps")
args = parser.parse_args()
# Reject nonsensical step counts up front instead of failing deep inside the scheduler.
if args.steps is not None and args.steps < 1:
    parser.error(f"--steps must be a positive integer, got {args.steps}")
13
14
15


# Path of the fine-tuned pipeline to load (placeholder, replace before running).
model_id = "path-to-your-trained-model"
# All inference below runs on the CPU.
device = "cpu"
# Generation prompt; `<dicoo>` is presumably the placeholder token learned during
# training — confirm it matches the token the model was trained with.
prompt = (
    "a lovely <dicoo> in red dress and hat, "
    "in the snowly and brightly night, "
    "with many brighly buildings"
)

pipe = StableDiffusionPipeline.from_pretrained(model_id)
if args.dpm:
    # Build the multistep DPM-Solver scheduler from the loaded scheduler's config.
    base_scheduler_config = pipe.scheduler.config
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(base_scheduler_config)
pipe = pipe.to(device)
23
24

# to channels last
# Convert every sub-model the pipeline will run to the channels_last memory
# format, the layout IPEX recommends for convolution-heavy models on CPU.
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
pipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last)
# Test for the module itself rather than `requires_safety_checker`. A pipeline
# saved/loaded without a checker has `safety_checker` set to None even when that
# flag is True, and calling `.to()` on None would crash.
if getattr(pipe, "safety_checker", None) is not None:
    pipe.safety_checker = pipe.safety_checker.to(memory_format=torch.channels_last)
30
31

# optimize with ipex
# Build an example UNet input (latents, timestep, text embeddings) so IPEX can
# optimise with realistic shapes. Shapes come from the loaded model rather than
# being hard-coded for SD v1 (4x64x64 latents, 77x768 embeddings), so other
# Stable Diffusion variants work too. The batch of 2 presumably mirrors the
# unconditional + conditional pair used by classifier-free guidance — TODO confirm.
unet_config = pipe.unet.config
latent_size = unet_config.sample_size
if isinstance(latent_size, int):
    latent_height = latent_width = latent_size
else:
    # Some configs store (height, width) instead of a single square size.
    latent_height, latent_width = latent_size
sample = torch.randn(2, unet_config.in_channels, latent_height, latent_width)
timestep = torch.rand(1) * 999
encoder_hidden_status = torch.randn(
    2, pipe.tokenizer.model_max_length, pipe.text_encoder.config.hidden_size
)
input_example = (sample, timestep, encoder_hidden_status)
try:
    pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=input_example)
except Exception as err:
    # Best effort: if optimising with `sample_input` fails (e.g. an IPEX version
    # without that argument), retry without it — but report why instead of
    # silently hiding the failure.
    print(f"ipex.optimize with sample_input failed ({err!r}); retrying without it.")
    pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True)
pipe.vae = ipex.optimize(pipe.vae.eval(), dtype=torch.bfloat16, inplace=True)
pipe.text_encoder = ipex.optimize(pipe.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
# Optimise the safety checker only when the pipeline actually has one; it can be
# None even when `requires_safety_checker` is True.
if getattr(pipe, "safety_checker", None) is not None:
    pipe.safety_checker = ipex.optimize(pipe.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
44
45
46
47

# compute
# Fixed seed for the initial noise, intended to make runs repeatable.
seed = 666
generator = torch.Generator(device).manual_seed(seed)

generate_kwargs = {"generator": generator}
if args.steps is not None:
    # Override the pipeline's default step count only when --steps was given.
    generate_kwargs["num_inference_steps"] = args.steps

# Run generation under bf16 autocast on the same device as the generator.
# `torch.autocast(device_type=...)` replaces the deprecated
# `torch.cpu.amp.autocast` spelling with the same semantics.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    image = pipe(prompt, **generate_kwargs).images[0]

# save image
image.save("generated.png")