"tests/vscode:/vscode.git/clone" did not exist on "7b98c4cc67b7131724f1cb5315da1c01387c6667"
Commit 391c1046 authored by comfyanonymous

More flexibility with text encoder return values.

Text encoders can now return values other than the cond and pooled output, and those
extra values are passed through to the CONDITIONING.
parent e44fa566
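
In practice this means a text encoder's encode_token_weights() may now hand back a third tuple element, a plain dict of extra values (such as an attention mask), and those values travel alongside the cond/pooled pair instead of being appended as bare tensors. A minimal sketch of that encoder-side contract, assuming an illustrative wrapper class and dummy tensor shapes (nothing below is code from the commit):

import torch

class MyTextEncoderWrapper:
    # Hypothetical stand-in that only demonstrates the allowed return shapes.
    def encode_token_weights(self, token_weight_pairs):
        cond = torch.zeros(1, 77, 768)     # per-token embeddings
        pooled = torch.zeros(1, 768)       # pooled output
        extras = {"attention_mask": torch.ones(1, 77, dtype=torch.long)}
        return cond, pooled, extras        # the dict is optional; (cond, pooled) still works
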
@@ -130,7 +130,7 @@ class CLIP:
     def tokenize(self, text, return_word_ids=False):
         return self.tokenizer.tokenize_with_weights(text, return_word_ids)

-    def encode_from_tokens(self, tokens, return_pooled=False):
+    def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
         self.cond_stage_model.reset_clip_options()

         if self.layer_idx is not None:
@@ -140,7 +140,15 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})

         self.load_model()
-        cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
+        o = self.cond_stage_model.encode_token_weights(tokens)
+        cond, pooled = o[:2]
+        if return_dict:
+            out = {"cond": cond, "pooled_output": pooled}
+            if len(o) > 2:
+                for k in o[2]:
+                    out[k] = o[2][k]
+            return out
+
         if return_pooled:
             return cond, pooled
         return cond
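
On the caller side, CLIP.encode_from_tokens() keeps the old return_pooled behaviour and adds a dict form that carries any extras along. A usage sketch, assuming a loaded clip object; the prompt string and variable names are illustrative:

tokens = clip.tokenize("a photo of a cat")

# Old call style, unchanged:
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)

# New call style: one dict holding cond, pooled_output and any extra keys
# (for example "attention_mask") the text encoder returned.
out = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = out["cond"]
pooled = out["pooled_output"]
extra_keys = [k for k in out if k not in ("cond", "pooled_output")]
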
@@ -62,7 +62,16 @@ class ClipTokenWeightEncoder:
             r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
         else:
             r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
-        r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
+
+        if len(o) > 2:
+            extra = {}
+            for k in o[2]:
+                v = o[2][k]
+                if k == "attention_mask":
+                    v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
+                extra[k] = v
+
+            r = r + (extra,)
         return r

 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
@@ -206,8 +215,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         elif outputs[2] is not None:
             pooled_output = outputs[2].float()

+        extra = {}
         if self.return_attention_masks:
-            return z, pooled_output, attention_mask
+            extra["attention_mask"] = attention_mask
+
+        if len(extra) > 0:
+            return z, pooled_output, extra

         return z, pooled_output
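
Taken together, the two hunks above route extras from the model to its wrapper: SDClipModel packs the attention mask into a dict, and ClipTokenWeightEncoder trims that mask to the token sections actually used, flattens it, and attaches the dict as the tuple's third element. A sketch of how a caller might unpack that tuple (the clip_model name is illustrative):

o = clip_model.encode_token_weights(token_weight_pairs)
cond, pooled = o[:2]
extras = o[2] if len(o) > 2 else {}
mask = extras.get("attention_mask")   # present only when return_attention_masks is enabled
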
@@ -547,8 +560,8 @@ class SD1ClipModel(torch.nn.Module):
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs = token_weight_pairs[self.clip_name]
-        out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
-        return out, pooled
+        out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
+        return out

     def load_sd(self, sd):
         return getattr(self, self.clip).load_sd(sd)
@@ -55,8 +55,9 @@ class CLIPTextEncode:
     def encode(self, clip, text):
         tokens = clip.tokenize(text)
-        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
-        return ([[cond, {"pooled_output": pooled}]], )
+        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
+        cond = output.pop("cond")
+        return ([[cond, output]], )

 class ConditioningCombine:
     @classmethod
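
With CLIPTextEncode switched to the dict form, the options dict inside each CONDITIONING entry now contains "pooled_output" plus whatever extras the encoder supplied. A sketch of reading it back, assuming a loaded clip object and an illustrative prompt:

conditioning = CLIPTextEncode().encode(clip, "a photo of a cat")[0]
cond, opts = conditioning[0]
pooled = opts["pooled_output"]
mask = opts.get("attention_mask")     # None unless the text encoder returned one
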