"examples/pytorch/vscode:/vscode.git/clone" did not exist on "a3ea4873d343d8876d961621ca3bfb889ec78501"
Unverified Commit 34bfe98e authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Gligen Text to Image fix (#5010)

* fix gligen clip import issue

* fix dtype issue with gligen text to image pipeline

* make fix copies
parent b47f5115
......@@ -197,6 +197,7 @@ else:
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
"AudioLDMPipeline",
"CLIPImageProjection",
"CycleDiffusionPipeline",
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
......@@ -530,6 +531,7 @@ if TYPE_CHECKING:
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
CLIPImageProjection,
CycleDiffusionPipeline,
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
......
......@@ -113,6 +113,7 @@ else:
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_diffusion"].extend(
[
"CLIPImageProjection",
"CycleDiffusionPipeline",
"StableDiffusionAttendAndExcitePipeline",
"StableDiffusionDepth2ImgPipeline",
......@@ -323,6 +324,7 @@ if TYPE_CHECKING:
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_diffusion import (
CLIPImageProjection,
CycleDiffusionPipeline,
StableDiffusionAttendAndExcitePipeline,
StableDiffusionDepth2ImgPipeline,
......
......@@ -582,6 +582,8 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
if input is None:
return None
inputs = self.processor(images=[input], return_tensors="pt").to(device)
inputs["pixel_values"] = inputs["pixel_values"].to(self.image_encoder.dtype)
outputs = self.image_encoder(**inputs)
feature = outputs.image_embeds
feature = self.image_project(feature).squeeze(0)
......
......@@ -92,6 +92,21 @@ class AudioLDMPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class CLIPImageProjection(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment