Unverified Commit 7b6caca9 authored by mengfei25's avatar mengfei25 Committed by GitHub
Browse files

Modify example with intel optimization (#2896)

* modify intel opts inference script

* modify readme

* modify doc

* fix some issues

* reformat

* reformat script

* format issue

* format issue
parent f3fbf9bf
......@@ -11,6 +11,26 @@ We accelereate the fine-tuning for textual inversion with Intel Extension for Py
## Accelerating the inference for Stable Diffusion using Bfloat16
We start the inference acceleration with Bfloat16 using Intel Extension for PyTorch. The [script](inference_bf16.py) is generally designed to support standard Stable Diffusion models with Bfloat16 support.
```bash
pip install diffusers transformers accelerate scipy safetensors
export KMP_BLOCKTIME=1
export KMP_SETTINGS=1
export KMP_AFFINITY=granularity=fine,compact,1,0
# Intel OpenMP
export OMP_NUM_THREADS=< Cores to use >
export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libiomp5.so
# Jemalloc is a recommended malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support.
export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libjemalloc.so
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:9000000000"
# Launch with default DDIM
numactl --membind <node N> -C <cpu list> python python inference_bf16.py
# Launch with DPMSolverMultistepScheduler
numactl --membind <node N> -C <cpu list> python python inference_bf16.py --dpm
```
## Accelerating the inference for Stable Diffusion using INT8
......
import argparse
import intel_extension_for_pytorch as ipex
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
parser = argparse.ArgumentParser("Stable Diffusion script with intel optimization", add_help=False)
parser.add_argument("--dpm", action="store_true", help="Enable DPMSolver or not")
parser.add_argument("--steps", default=None, type=int, help="Num inference steps")
args = parser.parse_args()
prompt = ["a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brighly buildings"]
batch_size = 8
prompt = prompt * batch_size
device = "cpu"
prompt = "a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brighly buildings"
model_id = "path-to-your-trained-model"
model = StableDiffusionPipeline.from_pretrained(model_id)
model = model.to(device)
pipe = StableDiffusionPipeline.from_pretrained(model_id)
if args.dpm:
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)
# to channels last
model.unet = model.unet.to(memory_format=torch.channels_last)
model.vae = model.vae.to(memory_format=torch.channels_last)
model.text_encoder = model.text_encoder.to(memory_format=torch.channels_last)
model.safety_checker = model.safety_checker.to(memory_format=torch.channels_last)
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
pipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last)
if pipe.requires_safety_checker:
pipe.safety_checker = pipe.safety_checker.to(memory_format=torch.channels_last)
# optimize with ipex
model.unet = ipex.optimize(model.unet.eval(), dtype=torch.bfloat16, inplace=True)
model.vae = ipex.optimize(model.vae.eval(), dtype=torch.bfloat16, inplace=True)
model.text_encoder = ipex.optimize(model.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
model.safety_checker = ipex.optimize(model.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
sample = torch.randn(2, 4, 64, 64)
timestep = torch.rand(1) * 999
encoder_hidden_status = torch.randn(2, 77, 768)
input_example = (sample, timestep, encoder_hidden_status)
try:
pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=input_example)
except Exception:
pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True)
pipe.vae = ipex.optimize(pipe.vae.eval(), dtype=torch.bfloat16, inplace=True)
pipe.text_encoder = ipex.optimize(pipe.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
if pipe.requires_safety_checker:
pipe.safety_checker = ipex.optimize(pipe.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
# compute
seed = 666
generator = torch.Generator(device).manual_seed(seed)
generate_kwargs = {"generator": generator}
if args.steps is not None:
generate_kwargs["num_inference_steps"] = args.steps
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
images = model(prompt, guidance_scale=7.5, num_inference_steps=50, generator=generator).images
image = pipe(prompt, **generate_kwargs).images[0]
# save image
grid = image_grid(images, rows=2, cols=4)
grid.save(model_id + ".png")
# save image
image.save("generated.png")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment