Commit 39ac856a authored by comfyanonymous's avatar comfyanonymous
Browse files

ReVision support: unclip nodes can now be used with SDXL.

parent 76d53c46
...@@ -148,13 +148,20 @@ class SDInpaint(BaseModel): ...@@ -148,13 +148,20 @@ class SDInpaint(BaseModel):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
self.concat_keys = ("mask", "masked_image") self.concat_keys = ("mask", "masked_image")
def sdxl_pooled(args, noise_augmentor):
if "unclip_conditioning" in args:
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
else:
return args["pooled_output"]
class SDXLRefiner(BaseModel): class SDXLRefiner(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256) self.embedder = Timestep(256)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
clip_pooled = kwargs["pooled_output"] clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
width = kwargs.get("width", 768) width = kwargs.get("width", 768)
height = kwargs.get("height", 768) height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0) crop_w = kwargs.get("crop_w", 0)
...@@ -178,9 +185,10 @@ class SDXL(BaseModel): ...@@ -178,9 +185,10 @@ class SDXL(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256) self.embedder = Timestep(256)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
clip_pooled = kwargs["pooled_output"] clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
width = kwargs.get("width", 768) width = kwargs.get("width", 768)
height = kwargs.get("height", 768) height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0) crop_w = kwargs.get("crop_w", 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment