Unverified Commit 0b42b074, authored by SkyTNT, committed by GitHub

[Onnx] support half-precision and fix bugs for onnx pipelines (#932)

* [Onnx] support half-precision and fix bugs for onnx pipelines

* Update convert_stable_diffusion_checkpoint_to_onnx.py

* style

* fix has_nsfw_concept

* Update convert_stable_diffusion_checkpoint_to_onnx.py

* fix style
parent 3d02c921
@@ -69,8 +69,15 @@ def onnx_export(
 @torch.no_grad()
-def convert_models(model_path: str, output_path: str, opset: int):
-    pipeline = StableDiffusionPipeline.from_pretrained(model_path)
+def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
+    dtype = torch.float16 if fp16 else torch.float32
+    if fp16 and torch.cuda.is_available():
+        device = "cuda"
+    elif fp16 and not torch.cuda.is_available():
+        raise ValueError("`float16` model export is only supported on GPUs with CUDA")
+    else:
+        device = "cpu"
+    pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
     output_path = Path(output_path)

     # TEXT ENCODER
@@ -84,7 +91,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
     onnx_export(
         pipeline.text_encoder,
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
-        model_args=(text_input.input_ids.to(torch.int32)),
+        model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)),
         output_path=output_path / "text_encoder" / "model.onnx",
         ordered_input_names=["input_ids"],
         output_names=["last_hidden_state", "pooler_output"],
@@ -100,9 +107,9 @@ def convert_models(model_path: str, output_path: str, opset: int):
     onnx_export(
         pipeline.unet,
         model_args=(
-            torch.randn(2, pipeline.unet.in_channels, 64, 64),
-            torch.LongTensor([0, 1]),
-            torch.randn(2, 77, 768),
+            torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
+            torch.LongTensor([0, 1]).to(device=device),
+            torch.randn(2, 77, 768).to(device=device, dtype=dtype),
             False,
         ),
         output_path=unet_path,
@@ -139,7 +146,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
     vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
     onnx_export(
         vae_encoder,
-        model_args=(torch.randn(1, 3, 512, 512), False),
+        model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
         output_path=output_path / "vae_encoder" / "model.onnx",
         ordered_input_names=["sample", "return_dict"],
         output_names=["latent_sample"],
@@ -155,7 +162,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
     vae_decoder.forward = vae_encoder.decode
     onnx_export(
         vae_decoder,
-        model_args=(torch.randn(1, 4, 64, 64), False),
+        model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
        output_path=output_path / "vae_decoder" / "model.onnx",
         ordered_input_names=["latent_sample", "return_dict"],
         output_names=["sample"],
@@ -171,13 +178,16 @@ def convert_models(model_path: str, output_path: str, opset: int):
     safety_checker.forward = safety_checker.forward_onnx
     onnx_export(
         pipeline.safety_checker,
-        model_args=(torch.randn(1, 3, 224, 224), torch.randn(1, 512, 512, 3)),
+        model_args=(
+            torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
+            torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
+        ),
         output_path=output_path / "safety_checker" / "model.onnx",
         ordered_input_names=["clip_input", "images"],
         output_names=["out_images", "has_nsfw_concepts"],
         dynamic_axes={
             "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-            "images": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+            "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
         },
         opset=opset,
     )
@@ -221,7 +231,8 @@ if __name__ == "__main__":
         type=int,
         help="The version of the ONNX operator set to use.",
     )
+    parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")

     args = parser.parse_args()

-    convert_models(args.model_path, args.output_path, args.opset)
+    convert_models(args.model_path, args.output_path, args.opset, args.fp16)
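
For reference, a hedged sketch of driving the updated converter directly from Python instead of the CLI. The checkpoint id, output directory, and opset value are placeholders, and the import path assumes the script's own directory is on the path; only the `convert_models` signature comes from the diff above.

```python
# convert_models now accepts fp16; float16 export requires a CUDA device and
# raises ValueError on CPU-only machines (see the check added above)
from convert_stable_diffusion_checkpoint_to_onnx import convert_models  # assumed import path

convert_models(
    model_path="CompVis/stable-diffusion-v1-4",  # placeholder checkpoint
    output_path="./sd_onnx_fp16",                # placeholder output directory
    opset=14,                                    # placeholder opset value
    fp16=True,
)
```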
@@ -55,7 +55,9 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
+        generator: Optional[np.random.RandomState] = None,
         latents: Optional[np.ndarray] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -81,6 +83,9 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
                 f" {type(callback_steps)}."
             )

+        if generator is None:
+            generator = np.random
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -98,6 +103,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+        text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -133,6 +139,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
                 return_tensors="np",
             )
             uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+            uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -140,9 +147,10 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
             text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

         # get the initial random noise unless the user supplied it
-        latents_shape = (batch_size, 4, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
+        latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
         if latents is None:
-            latents = np.random.randn(*latents_shape).astype(np.float32)
+            latents = generator.randn(*latents_shape).astype(latents_dtype)
         elif latents.shape != latents_shape:
             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -185,13 +193,30 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
                 callback(i, t, latents)

         latents = 1 / 0.18215 * latents
-        image = self.vae_decoder(latent_sample=latents)[0]
+        # image = self.vae_decoder(latent_sample=latents)[0]
+        # the half-precision vae decoder seems to give strange results for batch size > 1,
+        # so decode the latents one sample at a time
+        image = np.concatenate(
+            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+        )

         image = np.clip(image / 2 + 0.5, 0, 1)
         image = image.transpose((0, 2, 3, 1))

-        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
-        image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(
+                self.numpy_to_pil(image), return_tensors="np"
+            ).pixel_values.astype(image.dtype)
+            # the safety_checker throws an error for batch size > 1, so run it one image at a time
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
+                image_i, has_nsfw_concept_i = self.safety_checker(
+                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                )
+                images.append(image_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
+            image = np.concatenate(images)
+        else:
+            has_nsfw_concept = None

         if output_type == "pil":
             image = self.numpy_to_pil(image)
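As a usage sketch of the new `OnnxStableDiffusionPipeline` arguments (the `from_pretrained` path, execution provider, and prompt below are illustrative and not part of this diff), a seeded `np.random.RandomState` now makes batched generation reproducible; the same `generator` and `num_images_per_prompt` handling is added to the img2img and inpaint pipelines further down.

```python
import numpy as np
from diffusers import OnnxStableDiffusionPipeline

# load an exported ONNX pipeline; path and provider are placeholders
pipe = OnnxStableDiffusionPipeline.from_pretrained("./sd_onnx_fp16", provider="CUDAExecutionProvider")

rng = np.random.RandomState(0)            # seeded RandomState -> deterministic initial latents
result = pipe(
    "a photo of an astronaut riding a horse",
    num_images_per_prompt=2,              # prompt/uncond embeddings are repeated to match
    generator=rng,                        # consumed by generator.randn(...) when latents is None
)
images = result.images
```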
@@ -121,6 +121,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
+        generator: Optional[np.random.RandomState] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
@@ -159,6 +160,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`np.random.RandomState`, *optional*):
+                A np.random.RandomState instance to make generation deterministic.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -197,6 +200,9 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 f" {type(callback_steps)}."
             )

+        if generator is None:
+            generator = np.random
+
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
@@ -239,7 +245,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 f" {type(prompt)}."
             )
         elif isinstance(negative_prompt, str):
-            uncond_tokens = [negative_prompt]
+            uncond_tokens = [negative_prompt] * batch_size
         elif batch_size != len(negative_prompt):
             raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
         else:
@@ -257,13 +263,15 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
             uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]

             # duplicate unconditional embeddings for each generation per prompt
-            uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
+            uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.astype(latents_dtype)
         # encode the init image into latents and scale the latents
         init_latents = self.vae_encoder(sample=init_image)[0]
         init_latents = 0.18215 * init_latents
@@ -297,7 +305,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
         timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)

         # add noise to latents using the timesteps
-        noise = np.random.randn(*init_latents.shape).astype(np.float32)
+        noise = generator.randn(*init_latents.shape).astype(latents_dtype)
         init_latents = self.scheduler.add_noise(
             torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
         )
@@ -341,14 +349,28 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 callback(i, t, latents)

         latents = 1 / 0.18215 * latents
-        image = self.vae_decoder(latent_sample=latents)[0]
+        # image = self.vae_decoder(latent_sample=latents)[0]
+        # the half-precision vae decoder seems to give strange results for batch size > 1,
+        # so decode the latents one sample at a time
+        image = np.concatenate(
+            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+        )

         image = np.clip(image / 2 + 0.5, 0, 1)
         image = image.transpose((0, 2, 3, 1))

         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
-            image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+            safety_checker_input = self.feature_extractor(
+                self.numpy_to_pil(image), return_tensors="np"
+            ).pixel_values.astype(image.dtype)
+            # the safety_checker throws an error for batch size > 1, so run it one image at a time
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
+                image_i, has_nsfw_concept_i = self.safety_checker(
+                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                )
+                images.append(image_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
+            image = np.concatenate(images)
         else:
             has_nsfw_concept = None
@@ -23,11 +23,11 @@ NUM_LATENT_CHANNELS = 4

 def prepare_mask_and_masked_image(image, mask, latents_shape):
-    image = np.array(image.convert("RGB"))
+    image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
     image = image[None].transpose(0, 3, 1, 2)
     image = image.astype(np.float32) / 127.5 - 1.0

-    image_mask = np.array(mask.convert("L"))
+    image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
     masked_image = image * (image_mask < 127.5)

     mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
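
The resize fix above ties the pixel-space image and mask to the latent grid. Below is a minimal numpy sketch of what `prepare_mask_and_masked_image` now does for a 64×64 latent grid; the input sizes and the blank PIL images are placeholders, not part of the diff.

```python
import numpy as np
import PIL.Image

latents_shape = (64, 64)                                    # (height // 8, width // 8)
image = PIL.Image.new("RGB", (600, 400))                    # arbitrary input size; placeholder
mask = PIL.Image.new("L", (600, 400), color=255)            # white = region to repaint

pixel_size = (latents_shape[1] * 8, latents_shape[0] * 8)   # (512, 512), as in the diff
image_np = np.array(image.convert("RGB").resize(pixel_size))
image_np = image_np[None].transpose(0, 3, 1, 2).astype(np.float32) / 127.5 - 1.0
mask_np = np.array(mask.convert("L").resize(pixel_size))
masked_image = image_np * (mask_np < 127.5)                 # white (to-repaint) pixels are zeroed

print(image_np.shape, masked_image.shape)                   # (1, 3, 512, 512) for both
```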
@@ -138,6 +138,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
+        generator: Optional[np.random.RandomState] = None,
         latents: Optional[np.ndarray] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -180,6 +181,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`np.random.RandomState`, *optional*):
+                A np.random.RandomState instance to make generation deterministic.
             latents (`np.ndarray`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -222,6 +225,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
                 f" {type(callback_steps)}."
             )

+        if generator is None:
+            generator = np.random
+
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
@@ -261,7 +267,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
                 f" {type(prompt)}."
             )
         elif isinstance(negative_prompt, str):
-            uncond_tokens = [negative_prompt]
+            uncond_tokens = [negative_prompt] * batch_size
         elif batch_size != len(negative_prompt):
             raise ValueError(
                 f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -283,7 +289,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]

             # duplicate unconditional embeddings for each generation per prompt
-            uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
+            uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -294,7 +300,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = np.random.randn(*latents_shape).astype(latents_dtype)
+            latents = generator.randn(*latents_shape).astype(latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -307,6 +313,10 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         masked_image_latents = self.vae_encoder(sample=masked_image)[0]
         masked_image_latents = 0.18215 * masked_image_latents

+        # duplicate mask and masked_image_latents for each generation per prompt
+        mask = mask.repeat(batch_size * num_images_per_prompt, 0)
+        masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)
+
         mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (
             np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
@@ -367,14 +377,28 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
                 callback(i, t, latents)

         latents = 1 / 0.18215 * latents
-        image = self.vae_decoder(latent_sample=latents)[0]
+        # image = self.vae_decoder(latent_sample=latents)[0]
+        # the half-precision vae decoder seems to give strange results for batch size > 1,
+        # so decode the latents one sample at a time
+        image = np.concatenate(
+            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+        )

         image = np.clip(image / 2 + 0.5, 0, 1)
         image = image.transpose((0, 2, 3, 1))

         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
-            image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+            safety_checker_input = self.feature_extractor(
+                self.numpy_to_pil(image), return_tensors="np"
+            ).pixel_values.astype(image.dtype)
+            # the safety_checker throws an error for batch size > 1, so run it one image at a time
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
+                image_i, has_nsfw_concept_i = self.safety_checker(
+                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                )
+                images.append(image_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
+            image = np.concatenate(images)
         else:
             has_nsfw_concept = None
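For the inpainting changes, the mask and masked-image latents are now duplicated once per generated image and then doubled for classifier-free guidance, so their batch dimension matches the UNet input. A shape-bookkeeping sketch with plain numpy stand-ins (the shapes assume a 512×512 image and are illustrative only):

```python
import numpy as np

batch_size, num_images_per_prompt = 2, 3
mask = np.zeros((1, 1, 64, 64), dtype=np.float32)                  # stand-in for the prepared mask
masked_image_latents = np.zeros((1, 4, 64, 64), dtype=np.float32)  # stand-in for VAE-encoded masked image

# one copy per generated image...
mask = mask.repeat(batch_size * num_images_per_prompt, 0)                            # (6, 1, 64, 64)
masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)

# ...then doubled for classifier-free guidance, matching the 2 * 6 UNet batch
mask = np.concatenate([mask] * 2)                                                    # (12, 1, 64, 64)
masked_image_latents = np.concatenate([masked_image_latents] * 2)                    # (12, 4, 64, 64)
```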