Unverified Commit 3dc10a53 authored by asfiyab-nvidia, committed by GitHub

Update TensorRT txt2img and inpaint community pipelines (#9037)



* Update TensorRT txt2img and inpaint community pipelines
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* update tensorrt install instructions
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

---------
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent c370b90f
@@ -1487,17 +1486,16 @@ NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.
```python
import torch
from diffusers import DDIMScheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
+from diffusers.pipelines import DiffusionPipeline

# Use the DDIMScheduler scheduler here instead
-scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
-                                          subfolder="scheduler")
+scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")

-pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
    custom_pipeline="stable_diffusion_tensorrt_txt2img",
    variant='fp16',
    torch_dtype=torch.float16,
    scheduler=scheduler,)

# re-use cached folder to save ONNX models and TensorRT Engines
pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", variant='fp16',)
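# Illustrative continuation (the hunk above ends here): the README example then
# moves the pipeline to the GPU and runs it as usual. Prompt and output filename
# below are placeholders, not taken from the diff.
pipe = pipe.to("cuda")  # first run triggers ONNX export + engine build when no cache exists (see NOTE above)
image = pipe("a beautiful photograph of Mt. Fuji during cherry blossom").images[0]
image.save("tensorrt_txt2img.png")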
@@ -2231,12 +2230,12 @@ from io import BytesIO
from PIL import Image
import torch
from diffusers import PNDMScheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
+from diffusers.pipelines import DiffusionPipeline

# Use the PNDMScheduler scheduler here instead
scheduler = PNDMScheduler.from_pretrained("stabilityai/stable-diffusion-2-inpainting", subfolder="scheduler")

-pipe = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting",
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting",
    custom_pipeline="stable_diffusion_tensorrt_inpaint",
    variant='fp16',
    torch_dtype=torch.float16,
...
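# Illustrative continuation of the inpainting example (the hunk is cut off mid-call):
# the remaining steps mirror the txt2img snippet -- load an image and mask, move the
# pipeline to the GPU, and run it. URLs and prompt below are placeholders.
import requests
from io import BytesIO
from PIL import Image

def download_image(url):
    return Image.open(BytesIO(requests.get(url).content)).convert("RGB")

init_image = download_image("https://example.com/input.png")  # placeholder URL
mask_image = download_image("https://example.com/mask.png")   # placeholder URL

pipe = pipe.to("cuda")
image = pipe(
    prompt="a mecha robot sitting on a bench",  # illustrative prompt
    image=init_image,
    mask_image=mask_image,
).images[0]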
@@ -60,7 +60,7 @@ from diffusers.utils import logging
"""
Installation instructions
python3 -m pip install --upgrade transformers diffusers>=0.16.0
-python3 -m pip install --upgrade tensorrt-cu12==10.2.0
+python3 -m pip install --upgrade tensorrt~=10.2.0
python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
python3 -m pip install onnxruntime
"""
@@ -659,7 +659,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
    r"""
    Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion.

-    This model inherits from [`StableDiffusionImg2ImgPipeline`]. Check the superclass documentation for the generic methods the
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
...
@@ -18,8 +18,7 @@
import gc
import os
from collections import OrderedDict
-from copy import copy
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

import numpy as np
import onnx
@@ -27,9 +26,11 @@ import onnx_graphsurgeon as gs
import PIL.Image
import tensorrt as trt
import torch
+from cuda import cudart
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
+from packaging import version
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.onnx.loader import fold_constants
@@ -41,24 +42,29 @@ from polygraphy.backend.trt import (
    network_from_onnx_path,
    save_engine,
)
-from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
-    StableDiffusionInpaintPipeline,
    StableDiffusionPipelineOutput,
    StableDiffusionSafetyChecker,
)
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
+    prepare_mask_and_masked_image,
+    retrieve_latents,
+)
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor


"""
Installation instructions
python3 -m pip install --upgrade transformers diffusers>=0.16.0
-python3 -m pip install --upgrade tensorrt>=8.6.1
+python3 -m pip install --upgrade tensorrt~=10.2.0
python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
python3 -m pip install onnxruntime
"""
@@ -88,10 +94,6 @@ else:
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}


-def device_view(t):
-    return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
-
-
def preprocess_image(image):
    """
    image: torch.Tensor
@@ -125,10 +127,8 @@ class Engine:
        onnx_path,
        fp16,
        input_profile=None,
-        enable_preview=False,
        enable_all_tactics=False,
        timing_cache=None,
-        workspace_size=0,
    ):
        logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
        p = Profile()
@@ -137,20 +137,13 @@
                assert len(dims) == 3
                p.add(name, min=dims[0], opt=dims[1], max=dims[2])

-        config_kwargs = {}
-        config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
-        if enable_preview:
-            # Faster dynamic shapes made optional since it increases engine build time.
-            config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
-        if workspace_size > 0:
-            config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
+        extra_build_args = {}
        if not enable_all_tactics:
-            config_kwargs["tactic_sources"] = []
+            extra_build_args["tactic_sources"] = []

        engine = engine_from_network(
            network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
-            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
+            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
            save_timing_cache=timing_cache,
        )
        save_engine(engine, path=self.engine_path)
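The trimmed-down `build()` drops the preview-feature and workspace-size knobs that no longer apply in TensorRT 10. It is normally driven by `build_engines()` further down; a hand-rolled call would look roughly like this sketch (paths and profile shapes are illustrative, not taken from the diff):

```python
# Illustrative only: build_engines() usually supplies these arguments.
engine = Engine("engine/unet.plan")
engine.build(
    "onnx/unet.opt.onnx",
    fp16=True,
    # one (min, opt, max) shape triple per dynamic input
    input_profile={"sample": [(2, 4, 64, 64), (2, 4, 96, 96), (2, 4, 96, 96)]},
    timing_cache="timing_cache",
)
# the serialized engine is written to engine/unet.plan by save_engine()
```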
@@ -163,28 +156,24 @@
        self.context = self.engine.create_execution_context()

    def allocate_buffers(self, shape_dict=None, device="cuda"):
-        for idx in range(trt_util.get_bindings_per_profile(self.engine)):
-            binding = self.engine[idx]
-            if shape_dict and binding in shape_dict:
-                shape = shape_dict[binding]
+        for binding in range(self.engine.num_io_tensors):
+            name = self.engine.get_tensor_name(binding)
+            if shape_dict and name in shape_dict:
+                shape = shape_dict[name]
            else:
-                shape = self.engine.get_binding_shape(binding)
-            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
-            if self.engine.binding_is_input(binding):
-                self.context.set_binding_shape(idx, shape)
+                shape = self.engine.get_tensor_shape(name)
+            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
+            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                self.context.set_input_shape(name, shape)
            tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
-            self.tensors[binding] = tensor
-            self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
+            self.tensors[name] = tensor

    def infer(self, feed_dict, stream):
-        start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
-        # shallow copy of ordered dict
-        device_buffers = copy(self.buffers)
        for name, buf in feed_dict.items():
-            assert isinstance(buf, cuda.DeviceView)
-            device_buffers[name] = buf
-        bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
-        noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
+            self.tensors[name].copy_(buf)
+        for name, tensor in self.tensors.items():
+            self.context.set_tensor_address(name, tensor.data_ptr())
+        noerror = self.context.execute_async_v3(stream)
        if not noerror:
            raise ValueError("ERROR: inference failed.")
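This rewrite is the heart of the TensorRT 8.x to 10.x migration: the removed binding-index API (`get_binding_shape`, `binding_is_input`, `execute_async_v2`) is replaced by named I/O tensors, and Polygraphy `DeviceView` wrappers give way to copying inputs into pre-allocated torch tensors whose pointers are registered with `set_tensor_address`. A stand-alone sketch of the same pattern, independent of this pipeline (engine path is a placeholder, dtype handling omitted):

```python
import tensorrt as trt
import torch

logger = trt.Logger(trt.Logger.WARNING)
with open("unet.plan", "rb") as f:  # placeholder engine file
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

stream = torch.cuda.Stream()
tensors = {}
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    shape = engine.get_tensor_shape(name)
    if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
        context.set_input_shape(name, shape)
    tensors[name] = torch.empty(tuple(shape), device="cuda")
    context.set_tensor_address(name, tensors[name].data_ptr())

# after copying real inputs into the input tensors:
context.execute_async_v3(stream.cuda_stream)  # outputs land in the pre-allocated tensors
```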
@@ -325,10 +314,8 @@ def build_engines(
    force_engine_rebuild=False,
    static_batch=False,
    static_shape=True,
-    enable_preview=False,
    enable_all_tactics=False,
    timing_cache=None,
-    max_workspace_size=0,
):
    built_engines = {}
    if not os.path.isdir(onnx_dir):
@@ -393,9 +380,7 @@ def build_engines(
                    static_batch=static_batch,
                    static_shape=static_shape,
                ),
-                enable_preview=enable_preview,
                timing_cache=timing_cache,
-                workspace_size=max_workspace_size,
            )
            built_engines[model_name] = engine
@@ -674,11 +659,11 @@ def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False)
    return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)


-class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
+class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
    r"""
    Pipeline for inpainting using TensorRT accelerated Stable Diffusion.

-    This model inherits from [`StableDiffusionInpaintPipeline`]. Check the superclass documentation for the generic methods the
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
@@ -702,6 +687,8 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+
    def __init__(
        self,
        vae: AutoencoderKL,
@@ -722,24 +709,86 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
        onnx_dir: str = "onnx",
        # TensorRT engine build parameters
        engine_dir: str = "engine",
-        build_preview_features: bool = True,
        force_engine_rebuild: bool = False,
        timing_cache: str = "timing_cache",
    ):
-        super().__init__(
-            vae,
-            text_encoder,
-            tokenizer,
-            unet,
-            scheduler,
+        super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
-            requires_safety_checker=requires_safety_checker,
        )
-        self.vae.forward = self.vae.decode

        self.stages = stages
        self.image_height, self.image_width = image_height, image_width
        self.inpaint = True
@@ -750,7 +799,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
        self.timing_cache = timing_cache
        self.build_static_batch = False
        self.build_dynamic_shape = False
-        self.build_preview_features = build_preview_features
        self.max_batch_size = max_batch_size

        # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
@@ -761,6 +809,11 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
        self.models = {}  # loaded in __loadModels()
        self.engine = {}  # loaded in build_engines()

+        self.vae.forward = self.vae.decode
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
    def __loadModels(self):
        # Load pipeline models
        self.embedding_dim = self.text_encoder.config.hidden_size
@@ -779,6 +832,112 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
        if "vae_encoder" in self.stages:
            self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+        image=None,
+        timestep=None,
+        is_strength_max=True,
+        return_noise=False,
+        return_image_latents=False,
+    ):
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if (image is None or timestep is None) and not is_strength_max:
+            raise ValueError(
+                "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+                "However, either the image or the noise timestep has not been provided."
+            )
+
+        if return_image_latents or (latents is None and not is_strength_max):
+            image = image.to(device=device, dtype=dtype)
+
+            if image.shape[1] == 4:
+                image_latents = image
+            else:
+                image_latents = self._encode_vae_image(image=image, generator=generator)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+        if latents is None:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # if strength is 1. then initialise the latents to noise, else initial to image + noise
+            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+            # if pure noise then scale the initial latents by the Scheduler's init sigma
+            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+        else:
+            noise = latents.to(device)
+            latents = noise * self.scheduler.init_noise_sigma
+
+        outputs = (latents,)
+
+        if return_noise:
+            outputs += (noise,)
+
+        if return_image_latents:
+            outputs += (image_latents,)
+
+        return outputs
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(
+        self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
+    ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
+        r"""
+        Runs the safety checker on the given image.
+
+        Args:
+            image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
+            device (torch.device): The device to run the safety checker on.
+            dtype (torch.dtype): The data type of the input image.
+
+        Returns:
+            (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
+            a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
+        """
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
    @classmethod
    @validate_hf_hub_args
    def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -826,7 +985,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
            force_engine_rebuild=self.force_engine_rebuild,
            static_batch=self.build_static_batch,
            static_shape=not self.build_dynamic_shape,
-            enable_preview=self.build_preview_features,
            timing_cache=self.timing_cache,
        )
@@ -850,9 +1008,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
        return tuple(init_images)

    def __encode_image(self, init_image):
-        init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[
-            "latent"
-        ]
+        init_latents = runEngine(self.engine["vae_encoder"], {"images": init_image}, self.stream)["latent"]
        init_latents = 0.18215 * init_latents
        return init_latents
@@ -881,9 +1037,8 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
            .to(self.torch_device)
        )

-        text_input_ids_inp = device_view(text_input_ids)
        # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
-        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
+        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
            "text_embeddings"
        ].clone()
@@ -899,8 +1054,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
            .input_ids.type(torch.int32)
            .to(self.torch_device)
        )
-        uncond_input_ids_inp = device_view(uncond_input_ids)
-        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
+        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
            "text_embeddings"
        ]
@@ -924,18 +1078,15 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
            # Predict the noise residual
            timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep

-            sample_inp = device_view(latent_model_input)
-            timestep_inp = device_view(timestep_float)
-            embeddings_inp = device_view(text_embeddings)
            noise_pred = runEngine(
                self.engine["unet"],
-                {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
+                {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
                self.stream,
            )["latent"]

            # Perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+            noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
@@ -943,12 +1094,12 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
        return latents

    def __decode_latent(self, latents):
-        images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
+        images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
        images = (images / 2 + 0.5).clamp(0, 1)
        return images.cpu().permute(0, 2, 3, 1).float().numpy()

    def __loadResources(self, image_height, image_width, batch_size):
-        self.stream = cuda.Stream()
+        self.stream = cudart.cudaStreamCreate()[1]

        # Allocate buffers for TensorRT engine bindings
        for model_name, obj in self.models.items():
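The stream handle switches from Polygraphy's `cuda.Stream()` wrapper to a raw CUDA stream from the `cuda-python` bindings: `cudart.cudaStreamCreate()` returns an `(error, stream)` pair, hence the `[1]`, and the bare handle is what `execute_async_v3` expects. The same idiom in isolation (error handling kept minimal):

```python
from cuda import cudart

# cudart calls return (cudaError_t, result); index 1 is the stream handle
err, stream = cudart.cudaStreamCreate()
assert err == cudart.cudaError_t.cudaSuccess

# ... enqueue TensorRT work on `stream`, e.g. context.execute_async_v3(stream) ...

cudart.cudaStreamSynchronize(stream)
cudart.cudaStreamDestroy(stream)
```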
@@ -1112,5 +1263,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
            # VAE decode latent
            images = self.__decode_latent(latents)

+            images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)
            images = self.numpy_to_pil(images)
-            return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None)
+            return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
@@ -18,17 +18,19 @@
import gc
import os
from collections import OrderedDict
-from copy import copy
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

import numpy as np
import onnx
import onnx_graphsurgeon as gs
+import PIL.Image
import tensorrt as trt
import torch
+from cuda import cudart
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
+from packaging import version
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.onnx.loader import fold_constants
@@ -40,23 +42,25 @@ from polygraphy.backend.trt import (
    network_from_onnx_path,
    save_engine,
)
-from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
-    StableDiffusionPipeline,
    StableDiffusionPipelineOutput,
    StableDiffusionSafetyChecker,
)
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor


"""
Installation instructions
python3 -m pip install --upgrade transformers diffusers>=0.16.0
-python3 -m pip install --upgrade tensorrt>=8.6.1
+python3 -m pip install --upgrade tensorrt~=10.2.0
python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
python3 -m pip install onnxruntime
"""
@@ -86,10 +90,6 @@ else:
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}


-def device_view(t):
-    return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
-
-
class Engine:
    def __init__(self, engine_path):
        self.engine_path = engine_path
@@ -110,10 +110,8 @@
        onnx_path,
        fp16,
        input_profile=None,
-        enable_preview=False,
        enable_all_tactics=False,
        timing_cache=None,
-        workspace_size=0,
    ):
        logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
        p = Profile()
@@ -122,20 +120,13 @@
                assert len(dims) == 3
                p.add(name, min=dims[0], opt=dims[1], max=dims[2])

-        config_kwargs = {}
-        config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
-        if enable_preview:
-            # Faster dynamic shapes made optional since it increases engine build time.
-            config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
-        if workspace_size > 0:
-            config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
+        extra_build_args = {}
        if not enable_all_tactics:
-            config_kwargs["tactic_sources"] = []
+            extra_build_args["tactic_sources"] = []

        engine = engine_from_network(
            network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
-            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
+            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
            save_timing_cache=timing_cache,
        )
        save_engine(engine, path=self.engine_path)
@@ -148,28 +139,24 @@
        self.context = self.engine.create_execution_context()

    def allocate_buffers(self, shape_dict=None, device="cuda"):
-        for idx in range(trt_util.get_bindings_per_profile(self.engine)):
-            binding = self.engine[idx]
-            if shape_dict and binding in shape_dict:
-                shape = shape_dict[binding]
+        for binding in range(self.engine.num_io_tensors):
+            name = self.engine.get_tensor_name(binding)
+            if shape_dict and name in shape_dict:
+                shape = shape_dict[name]
            else:
-                shape = self.engine.get_binding_shape(binding)
-            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
-            if self.engine.binding_is_input(binding):
-                self.context.set_binding_shape(idx, shape)
+                shape = self.engine.get_tensor_shape(name)
+            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
+            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                self.context.set_input_shape(name, shape)
            tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
-            self.tensors[binding] = tensor
-            self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
+            self.tensors[name] = tensor

    def infer(self, feed_dict, stream):
-        start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
-        # shallow copy of ordered dict
-        device_buffers = copy(self.buffers)
        for name, buf in feed_dict.items():
-            assert isinstance(buf, cuda.DeviceView)
-            device_buffers[name] = buf
-        bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
-        noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
+            self.tensors[name].copy_(buf)
+        for name, tensor in self.tensors.items():
+            self.context.set_tensor_address(name, tensor.data_ptr())
+        noerror = self.context.execute_async_v3(stream)
        if not noerror:
            raise ValueError("ERROR: inference failed.")
@@ -310,10 +297,8 @@ def build_engines(
    force_engine_rebuild=False,
    static_batch=False,
    static_shape=True,
-    enable_preview=False,
    enable_all_tactics=False,
    timing_cache=None,
-    max_workspace_size=0,
):
    built_engines = {}
    if not os.path.isdir(onnx_dir):
@@ -378,9 +363,7 @@ def build_engines(
                    static_batch=static_batch,
                    static_shape=static_shape,
                ),
-                enable_preview=enable_preview,
                timing_cache=timing_cache,
-                workspace_size=max_workspace_size,
            )
            built_engines[model_name] = engine
@@ -588,11 +571,11 @@ def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
    return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)


-class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
+class TensorRTStableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion.

-    This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
@@ -616,6 +599,8 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

+    _optional_components = ["safety_checker", "feature_extractor"]
+
    def __init__(
        self,
        vae: AutoencoderKL,
@@ -632,28 +617,90 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
        image_width: int = 768,
        max_batch_size: int = 16,
        # ONNX export parameters
-        onnx_opset: int = 17,
+        onnx_opset: int = 18,
        onnx_dir: str = "onnx",
        # TensorRT engine build parameters
        engine_dir: str = "engine",
-        build_preview_features: bool = True,
        force_engine_rebuild: bool = False,
        timing_cache: str = "timing_cache",
    ):
-        super().__init__(
-            vae,
-            text_encoder,
-            tokenizer,
-            unet,
-            scheduler,
+        super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
-            requires_safety_checker=requires_safety_checker,
        )
-        self.vae.forward = self.vae.decode

        self.stages = stages
        self.image_height, self.image_width = image_height, image_width
        self.inpaint = False
@@ -664,7 +711,6 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
        self.timing_cache = timing_cache
        self.build_static_batch = False
        self.build_dynamic_shape = False
-        self.build_preview_features = build_preview_features
        self.max_batch_size = max_batch_size

        # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
@@ -675,6 +721,11 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
        self.models = {}  # loaded in __loadModels()
        self.engine = {}  # loaded in build_engines()

+        self.vae.forward = self.vae.decode
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
    def __loadModels(self):
        # Load pipeline models
        self.embedding_dim = self.text_encoder.config.hidden_size
@@ -691,6 +742,75 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
        if "vae" in self.stages:
            self.models["vae"] = make_VAE(self.vae, **models_args)

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+    def prepare_latents(
+        self,
+        batch_size: int,
+        num_channels_latents: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        generator: Union[torch.Generator, List[torch.Generator]],
+        latents: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Prepare the latent vectors for diffusion.
+
+        Args:
+            batch_size (int): The number of samples in the batch.
+            num_channels_latents (int): The number of channels in the latent vectors.
+            height (int): The height of the latent vectors.
+            width (int): The width of the latent vectors.
+            dtype (torch.dtype): The data type of the latent vectors.
+            device (torch.device): The device to place the latent vectors on.
+            generator (Union[torch.Generator, List[torch.Generator]]): The generator(s) to use for random number generation.
+            latents (Optional[torch.Tensor]): The pre-existing latent vectors. If None, new latent vectors will be generated.
+
+        Returns:
+            torch.Tensor: The prepared latent vectors.
+        """
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(
+        self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
+    ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
+        r"""
+        Runs the safety checker on the given image.
+
+        Args:
+            image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
+            device (torch.device): The device to run the safety checker on.
+            dtype (torch.dtype): The data type of the input image.
+
+        Returns:
+            (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
+            a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
+        """
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
    @classmethod
    @validate_hf_hub_args
    def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -738,7 +858,6 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
            force_engine_rebuild=self.force_engine_rebuild,
            static_batch=self.build_static_batch,
            static_shape=not self.build_dynamic_shape,
-            enable_preview=self.build_preview_features,
            timing_cache=self.timing_cache,
        )
@@ -769,9 +888,8 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
            .to(self.torch_device)
        )

-        text_input_ids_inp = device_view(text_input_ids)
        # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
-        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
+        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
            "text_embeddings"
        ].clone()
@@ -787,8 +905,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
            .input_ids.type(torch.int32)
            .to(self.torch_device)
        )
-        uncond_input_ids_inp = device_view(uncond_input_ids)
-        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
+        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
            "text_embeddings"
        ]
@@ -812,18 +929,15 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
            # Predict the noise residual
            timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep

-            sample_inp = device_view(latent_model_input)
-            timestep_inp = device_view(timestep_float)
-            embeddings_inp = device_view(text_embeddings)
            noise_pred = runEngine(
                self.engine["unet"],
-                {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
+                {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
                self.stream,
            )["latent"]

            # Perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+            noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
@@ -831,12 +945,12 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
        return latents

    def __decode_latent(self, latents):
-        images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
+        images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
        images = (images / 2 + 0.5).clamp(0, 1)
        return images.cpu().permute(0, 2, 3, 1).float().numpy()

    def __loadResources(self, image_height, image_width, batch_size):
-        self.stream = cuda.Stream()
+        self.stream = cudart.cudaStreamCreate()[1]

        # Allocate buffers for TensorRT engine bindings
        for model_name, obj in self.models.items():
...