Unverified Commit 7547f9b4 authored by dymil's avatar dymil Committed by GitHub
Browse files

Fix timestep dtype in legacy inpaint (#2120)

* Fix timestep dtype in legacy inpaint

This matches the structure in the text2img, img2img, and inpaint ONNX pipelines

* Fix style in dtype patch
parent a87e87fc
......@@ -10,7 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
from ..onnx_utils import OnnxRuntimeModel
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
......@@ -391,6 +391,10 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:].numpy()
timestep_dtype = next(
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
......@@ -398,9 +402,10 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=prompt_embeds
)[0]
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
0
]
# perform guidance
if do_classifier_free_guidance:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment