Unverified commit 0bda1d7b, authored by asfiyab-nvidia and committed by GitHub

Update TensorRT img2img community pipeline (#8899)



* Update TensorRT img2img pipeline
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* Update TensorRT version installed
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* make style and quality
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* Update examples/community/stable_diffusion_tensorrt_img2img.py
Co-authored-by: Tolga Cangöz <46008593+tolgacangoz@users.noreply.github.com>

* Update examples/community/README.md
Co-authored-by: Tolga Cangöz <46008593+tolgacangoz@users.noreply.github.com>

* Apply style and quality using ruff 0.1.5
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

---------
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Co-authored-by: Tolga Cangöz <46008593+tolgacangoz@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 527430d0
examples/community/README.md
@@ -1641,18 +1641,18 @@ from io import BytesIO
 from PIL import Image
 import torch
 from diffusers import DDIMScheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionImg2ImgPipeline
+from diffusers import DiffusionPipeline

 # Use the DDIMScheduler scheduler here instead
 scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
                                           subfolder="scheduler")

-pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
                                          custom_pipeline="stable_diffusion_tensorrt_img2img",
                                          variant='fp16',
                                          torch_dtype=torch.float16,
                                          scheduler=scheduler,)

 # re-use cached folder to save ONNX models and TensorRT Engines
 pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", variant='fp16',)
...
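For reference, the updated README example can be continued roughly as follows to actually run the pipeline. This is a sketch, not part of the diff: the image URL, prompt, and strength value are illustrative placeholders.

import requests  # assumed available; only used here to fetch a sample input image

pipe = pipe.to("cuda")

# Fetch an initial image and run TensorRT-accelerated img2img.
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((768, 768))

prompt = "photorealistic new zealand hills"  # placeholder prompt
image = pipe(prompt=prompt, image=init_image, strength=0.75).images[0]
image.save("tensorrt_img2img_output.png")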
examples/community/stable_diffusion_tensorrt_img2img.py
@@ -18,8 +18,7 @@
 import gc
 import os
 from collections import OrderedDict
-from copy import copy
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import onnx
@@ -27,9 +26,11 @@ import onnx_graphsurgeon as gs
 import PIL.Image
 import tensorrt as trt
 import torch
+from cuda import cudart
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import validate_hf_hub_args
 from onnx import shape_inference
+from packaging import version
 from polygraphy import cuda
 from polygraphy.backend.common import bytes_from_path
 from polygraphy.backend.onnx.loader import fold_constants
@@ -41,12 +42,13 @@ from polygraphy.backend.trt import (
     network_from_onnx_path,
     save_engine,
 )
-from polygraphy.backend.trt import util as trt_util
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
-    StableDiffusionImg2ImgPipeline,
     StableDiffusionPipelineOutput,
     StableDiffusionSafetyChecker,
 )
@@ -58,7 +60,7 @@ from diffusers.utils import logging
 """
 Installation instructions
 python3 -m pip install --upgrade transformers diffusers>=0.16.0
-python3 -m pip install --upgrade tensorrt>=8.6.1
+python3 -m pip install --upgrade tensorrt-cu12==10.2.0
 python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
 python3 -m pip install onnxruntime
 """
@@ -88,10 +90,6 @@ else:
 torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}


-def device_view(t):
-    return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
-
-
 def preprocess_image(image):
     """
     image: torch.Tensor
@@ -125,10 +123,8 @@ class Engine:
         onnx_path,
         fp16,
         input_profile=None,
-        enable_preview=False,
         enable_all_tactics=False,
         timing_cache=None,
-        workspace_size=0,
     ):
         logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
         p = Profile()
@@ -137,20 +133,13 @@
                 assert len(dims) == 3
                 p.add(name, min=dims[0], opt=dims[1], max=dims[2])

-        config_kwargs = {}
-        config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
-        if enable_preview:
-            # Faster dynamic shapes made optional since it increases engine build time.
-            config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
-        if workspace_size > 0:
-            config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
+        extra_build_args = {}
         if not enable_all_tactics:
-            config_kwargs["tactic_sources"] = []
+            extra_build_args["tactic_sources"] = []

         engine = engine_from_network(
             network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
-            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
+            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
             save_timing_cache=timing_cache,
         )
         save_engine(engine, path=self.engine_path)
@@ -163,28 +152,24 @@ class Engine:
         self.context = self.engine.create_execution_context()

     def allocate_buffers(self, shape_dict=None, device="cuda"):
-        for idx in range(trt_util.get_bindings_per_profile(self.engine)):
-            binding = self.engine[idx]
-            if shape_dict and binding in shape_dict:
-                shape = shape_dict[binding]
+        for binding in range(self.engine.num_io_tensors):
+            name = self.engine.get_tensor_name(binding)
+            if shape_dict and name in shape_dict:
+                shape = shape_dict[name]
             else:
-                shape = self.engine.get_binding_shape(binding)
-            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
-            if self.engine.binding_is_input(binding):
-                self.context.set_binding_shape(idx, shape)
+                shape = self.engine.get_tensor_shape(name)
+            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
+            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                self.context.set_input_shape(name, shape)
             tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
-            self.tensors[binding] = tensor
-            self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
+            self.tensors[name] = tensor

     def infer(self, feed_dict, stream):
-        start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
-        # shallow copy of ordered dict
-        device_buffers = copy(self.buffers)
         for name, buf in feed_dict.items():
-            assert isinstance(buf, cuda.DeviceView)
-            device_buffers[name] = buf
-        bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
-        noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
+            self.tensors[name].copy_(buf)
+        for name, tensor in self.tensors.items():
+            self.context.set_tensor_address(name, tensor.data_ptr())
+        noerror = self.context.execute_async_v3(stream)
         if not noerror:
             raise ValueError("ERROR: inference failed.")
@@ -325,10 +310,8 @@ def build_engines(
     force_engine_rebuild=False,
     static_batch=False,
     static_shape=True,
-    enable_preview=False,
     enable_all_tactics=False,
     timing_cache=None,
-    max_workspace_size=0,
 ):
     built_engines = {}
     if not os.path.isdir(onnx_dir):
@@ -393,9 +376,7 @@
                     static_batch=static_batch,
                     static_shape=static_shape,
                 ),
-                enable_preview=enable_preview,
                 timing_cache=timing_cache,
-                workspace_size=max_workspace_size,
             )
         built_engines[model_name] = engine
@@ -674,7 +655,7 @@ def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False)
     return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)


-class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
+class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     r"""
     Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion.
@@ -702,6 +683,8 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """

+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -722,24 +705,86 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
         onnx_dir: str = "onnx",
         # TensorRT engine build parameters
         engine_dir: str = "engine",
-        build_preview_features: bool = True,
         force_engine_rebuild: bool = False,
         timing_cache: str = "timing_cache",
     ):
-        super().__init__(
-            vae,
-            text_encoder,
-            tokenizer,
-            unet,
-            scheduler,
+        super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
-            requires_safety_checker=requires_safety_checker,
         )
-        self.vae.forward = self.vae.decode

         self.stages = stages
         self.image_height, self.image_width = image_height, image_width
         self.inpaint = False
@@ -750,7 +795,6 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
         self.timing_cache = timing_cache
         self.build_static_batch = False
         self.build_dynamic_shape = False
-        self.build_preview_features = build_preview_features

         self.max_batch_size = max_batch_size
         # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
@@ -761,6 +805,11 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
         self.models = {}  # loaded in __loadModels()
         self.engine = {}  # loaded in build_engines()

+        self.vae.forward = self.vae.decode
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
     def __loadModels(self):
         # Load pipeline models
         self.embedding_dim = self.text_encoder.config.hidden_size
@@ -779,6 +828,33 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
         if "vae_encoder" in self.stages:
             self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+    def run_safety_checker(
+        self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
+    ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
+        r"""
+        Runs the safety checker on the given image.
+
+        Args:
+            image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
+            device (torch.device): The device to run the safety checker on.
+            dtype (torch.dtype): The data type of the input image.
+
+        Returns:
+            (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
+            a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
+        """
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
     @classmethod
     @validate_hf_hub_args
     def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -826,7 +902,6 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
             force_engine_rebuild=self.force_engine_rebuild,
             static_batch=self.build_static_batch,
             static_shape=not self.build_dynamic_shape,
-            enable_preview=self.build_preview_features,
             timing_cache=self.timing_cache,
         )
@@ -850,9 +925,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
         return tuple(init_images)

     def __encode_image(self, init_image):
-        init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[
-            "latent"
-        ]
+        init_latents = runEngine(self.engine["vae_encoder"], {"images": init_image}, self.stream)["latent"]
         init_latents = 0.18215 * init_latents
         return init_latents
@@ -881,9 +954,8 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
             .to(self.torch_device)
         )

-        text_input_ids_inp = device_view(text_input_ids)
         # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
-        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
+        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
             "text_embeddings"
         ].clone()
@@ -899,8 +971,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
             .input_ids.type(torch.int32)
             .to(self.torch_device)
         )
-        uncond_input_ids_inp = device_view(uncond_input_ids)
-        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
+        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
             "text_embeddings"
         ]
@@ -924,18 +995,15 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
             # Predict the noise residual
             timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep

-            sample_inp = device_view(latent_model_input)
-            timestep_inp = device_view(timestep_float)
-            embeddings_inp = device_view(text_embeddings)
             noise_pred = runEngine(
                 self.engine["unet"],
-                {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
+                {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
                 self.stream,
             )["latent"]

             # Perform guidance
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+            noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)

             latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
@@ -943,12 +1011,12 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
         return latents

     def __decode_latent(self, latents):
-        images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
+        images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
        images = (images / 2 + 0.5).clamp(0, 1)
         return images.cpu().permute(0, 2, 3, 1).float().numpy()

     def __loadResources(self, image_height, image_width, batch_size):
-        self.stream = cuda.Stream()
+        self.stream = cudart.cudaStreamCreate()[1]

         # Allocate buffers for TensorRT engine bindings
         for model_name, obj in self.models.items():
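The stream change above relies on the cuda-python package (the new `from cuda import cudart` import): cudart.cudaStreamCreate() returns an (error, handle) pair, which is why the pipeline indexes `[1]`. A small illustrative sketch of that stream lifecycle, not taken from the diff:

from cuda import cudart

err, stream = cudart.cudaStreamCreate()  # returns (cudaError_t, stream handle)
assert err == cudart.cudaError_t.cudaSuccess

# ... enqueue asynchronous work on `stream`, e.g. context.execute_async_v3(stream) ...

cudart.cudaStreamSynchronize(stream)  # wait for queued work to finish
cudart.cudaStreamDestroy(stream)      # release the stream when done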
@@ -1061,5 +1129,6 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
             # VAE decode latent
             images = self.__decode_latent(latents)

+            images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)
             images = self.numpy_to_pil(images)
-        return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None)
+        return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
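With this last hunk the pipeline output carries the safety-checker verdict instead of a hard-coded None. A caller could inspect it roughly as follows; this is a sketch assuming `pipe` and `init_image` are set up as in the README example above, with a placeholder prompt:

output = pipe(prompt="photorealistic new zealand hills", image=init_image, strength=0.75)

# nsfw_content_detected is a list of per-image booleans, or None when the safety checker is disabled.
if output.nsfw_content_detected and any(output.nsfw_content_detected):
    print("The safety checker flagged at least one generated image.")
else:
    output.images[0].save("output.png")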