"deploy/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "1a2c400750a73bf7fd9381efdc89c0152f7aa4cf"
Commit da115bd7 authored by BlenderNeko

ensure backwards compat with optional args

parent 752f7a16
@@ -372,12 +372,16 @@ class CLIP:
     def clip_layer(self, layer_idx):
         self.layer_idx = layer_idx
 
-    def tokenize(self, text):
-        return self.tokenizer.tokenize_with_weights(text)
+    def tokenize(self, text, return_word_ids=False):
+        return self.tokenizer.tokenize_with_weights(text, return_word_ids)
 
-    def encode(self, tokens):
+    def encode(self, text, from_tokens=False):
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
+        if from_tokens:
+            tokens = text
+        else:
+            tokens = self.tokenizer.tokenize_with_weights(text)
         try:
             self.patcher.patch_model()
             cond = self.cond_stage_model.encode_token_weights(tokens)
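A rough illustration of the backwards-compat point in the commit message (a sketch, assuming `clip` is an instance of the CLIP class above; the prompt string is made up). Both new arguments default to the old behaviour, so pre-existing callers keep working unchanged:

    tokens = clip.tokenize("a photo of a cat")  # return_word_ids defaults to False,
                                                # so this still returns (token, weight) pairs
    cond = clip.encode("a photo of a cat")      # from_tokens defaults to False,
                                                # so encode still accepts a raw prompt string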
@@ -240,7 +240,7 @@ class SD1Tokenizer:
             return (embed, "")
 
-    def tokenize_with_weights(self, text:str):
+    def tokenize_with_weights(self, text:str, return_word_ids=False):
         '''
         Takes a prompt and converts it to a list of (token, weight, word id) elements.
         Tokens can both be integer tokens and pre computed CLIP tensors.
@@ -301,6 +301,10 @@ class SD1Tokenizer:
         #add start and end tokens
         batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens]
 
+        if not return_word_ids:
+            batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
+
         return batched_tokens
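For the tokenizer itself, the difference is only in the shape of the return value (a sketch, assuming `tokenizer` is an SD1Tokenizer instance; the `tok_*` values stand in for real CLIP token ids and the word-index numbering is illustrative):

    pairs = tokenizer.tokenize_with_weights("hello world")
    # -> [[(start_token, 1.0), (tok_hello, 1.0), (tok_world, 1.0), ..., (end_token, 1.0)]]
    triples = tokenizer.tokenize_with_weights("hello world", return_word_ids=True)
    # -> [[(start_token, 1.0, 0), (tok_hello, 1.0, 1), (tok_world, 1.0, 2), ..., (end_token, 1.0, 0)]]
    # start and end tokens carry word id 0, as set in the diff above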
@@ -44,8 +44,7 @@ class CLIPTextEncode:
     CATEGORY = "conditioning"
 
     def encode(self, clip, text):
-        tokens = clip.tokenize(text)
-        return ([[clip.encode(tokens), {}]], )
+        return ([[clip.encode(text), {}]], )
 
 class ConditioningCombine:
     @classmethod
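Since CLIPTextEncode now hands the raw prompt straight to clip.encode, a downstream node that still wants token-level access would opt in through the new arguments instead. A hypothetical sketch, not code from this commit (MyWeightedEncode and the weight tweak are made up):

    class MyWeightedEncode:
        def encode(self, clip, text):
            tokens = clip.tokenize(text)  # legacy (token, weight) pairs
            tokens = [[(t, w * 1.1) for t, w in batch] for batch in tokens]  # example tweak
            return ([[clip.encode(tokens, from_tokens=True), {}]], )  # skip re-tokenizing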