Commit 391c1046 authored by comfyanonymous

More flexibility with text encoder return values.

Text encoders can now return values other than the cond and the pooled output
to the CONDITIONING.
parent e44fa566
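
In short: a text encoder's encode_token_weights() may now return an optional third element, a dict of extra values (currently the attention mask), and encode_from_tokens(..., return_dict=True) merges those values into the CONDITIONING alongside cond and pooled_output. A minimal sketch of the new contract, assuming hypothetical tensor shapes and a toy encoder name:

import torch

class ToyEncoder:
    # Illustrative only: any encoder can follow the new return contract.
    def encode_token_weights(self, token_weight_pairs):
        cond = torch.zeros(1, 77, 768)   # token embeddings (hypothetical shape)
        pooled = torch.zeros(1, 768)     # pooled output
        extra = {"attention_mask": torch.ones(1, 77, dtype=torch.long)}
        return cond, pooled, extra       # the third element is optional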
comfy/sd.py
@@ -130,7 +130,7 @@ class CLIP:
     def tokenize(self, text, return_word_ids=False):
         return self.tokenizer.tokenize_with_weights(text, return_word_ids)
 
-    def encode_from_tokens(self, tokens, return_pooled=False):
+    def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
         self.cond_stage_model.reset_clip_options()
 
         if self.layer_idx is not None:
@@ -140,7 +140,15 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model()
-        cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
+        o = self.cond_stage_model.encode_token_weights(tokens)
+        cond, pooled = o[:2]
+        if return_dict:
+            out = {"cond": cond, "pooled_output": pooled}
+            if len(o) > 2:
+                for k in o[2]:
+                    out[k] = o[2][k]
+            return out
+
         if return_pooled:
             return cond, pooled
         return cond
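
Call sites can now opt into the dict form instead of unpacking a variable-length tuple. An illustrative usage, assuming a loaded clip object:

tokens = clip.tokenize("a photo of a cat")

# Old behaviour, unchanged:
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)

# New behaviour: extra encoder outputs ride along in the dict, e.g.
# {"cond": ..., "pooled_output": ..., "attention_mask": ...}
output = clip.encode_from_tokens(tokens, return_dict=True)
cond = output.pop("cond")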
comfy/sd1_clip.py
@@ -62,7 +62,16 @@ class ClipTokenWeightEncoder:
             r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
         else:
             r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
-        r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
+
+        if len(o) > 2:
+            extra = {}
+            for k in o[2]:
+                v = o[2][k]
+                if k == "attention_mask":
+                    v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
+                extra[k] = v
+
+            r = r + (extra,)
         return r
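
attention_mask is the one key that needs per-chunk handling here: long prompts are encoded as a batch of fixed-size sections, so the per-section masks are trimmed to the sections actually used and flattened back into a single row. A standalone sketch of that reshaping, with made-up sizes:

import torch

sections = 2                                  # prompt chunks actually used
mask = torch.ones(3, 77, dtype=torch.long)    # one mask row per encoded chunk
merged = mask[:sections].flatten().unsqueeze(dim=0)
print(merged.shape)                           # torch.Size([1, 154])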
@@ -206,8 +215,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         elif outputs[2] is not None:
             pooled_output = outputs[2].float()
 
+        extra = {}
         if self.return_attention_masks:
-            return z, pooled_output, attention_mask
+            extra["attention_mask"] = attention_mask
+
+        if len(extra) > 0:
+            return z, pooled_output, extra
 
         return z, pooled_output
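
Returning the mask inside a named extra dict, rather than as a bare third tuple element as before, keeps the return arity at two or three and lets downstream code look values up by key instead of by position:

# Before: positional, order-sensitive:
#     return z, pooled_output, attention_mask
# After: named, extensible with future keys:
#     return z, pooled_output, {"attention_mask": attention_mask}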
@@ -547,8 +560,8 @@ class SD1ClipModel(torch.nn.Module):
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs = token_weight_pairs[self.clip_name]
-        out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
-        return out, pooled
+        out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
+        return out
 
     def load_sd(self, sd):
         return getattr(self, self.clip).load_sd(sd)
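
Because the wrapper now forwards the inner encoder's result unmodified, a three-element (cond, pooled, extra) tuple survives the pass-through; the old two-name unpack would have raised ValueError: too many values to unpack for such encoders.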
nodes.py
@@ -55,8 +55,9 @@ class CLIPTextEncode:
     def encode(self, clip, text):
         tokens = clip.tokenize(text)
-        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
-        return ([[cond, {"pooled_output": pooled}]], )
+        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
+        cond = output.pop("cond")
+        return ([[cond, output]], )
 
 class ConditioningCombine:
     @classmethod
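
At the node level, the second element of each CONDITIONING entry is now the whole dict returned by encode_from_tokens minus cond, so extra keys such as attention_mask flow into the conditioning automatically. Illustrative structure, tensor values elided:

conditioning = [[cond, {"pooled_output": pooled, "attention_mask": mask}]]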