Commit fbdb14d4 authored by comfyanonymous

Cleaner CLIP text encoder implementation.

Use a simple CLIP model implementation instead of the one from
transformers.

This will allow some interesting things that would be too hackish to implement
using the transformers implementation.
parent 2db86b46
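comfy/clip_model.py (new file)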
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
class CLIPAttention(torch.nn.Module):
    def __init__(self, embed_dim, heads, dtype, device, operations):
        super().__init__()

        self.heads = heads
        self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x, mask=None, optimized_attention=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        out = optimized_attention(q, k, v, self.heads, mask)
        return self.out_proj(out)
# "quick_gelu" is the sigmoid-based GELU approximation used by the original CLIP weights.
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
               "gelu": torch.nn.functional.gelu,
}
class CLIPMLP(torch.nn.Module):
    def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
        super().__init__()
        self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
        self.activation = ACTIVATIONS[activation]
        self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x
class CLIPLayer(torch.nn.Module):
    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
        super().__init__()
        self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
        self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)

    def forward(self, x, mask=None, optimized_attention=None):
        x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
        x += self.mlp(self.layer_norm2(x))
        return x
class CLIPEncoder(torch.nn.Module):
    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
        super().__init__()
        self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])

    def forward(self, x, mask=None, intermediate_output=None):
        optimized_attention = optimized_attention_for_device(x.device, mask=True)

        # Additive causal mask: -inf above the diagonal, so each token attends only to itself and earlier tokens.
        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        if mask is not None:
            mask += causal_mask
        else:
            mask = causal_mask

        if intermediate_output is not None:
            if intermediate_output < 0:
                # negative indices count back from the last layer, like list indexing
                intermediate_output = len(self.layers) + intermediate_output

        intermediate = None
        for i, l in enumerate(self.layers):
            x = l(x, mask, optimized_attention)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens):
        return self.token_embedding(input_tokens) + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]

        super().__init__()
        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
        self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
        x = self.embeddings(input_tokens)
        #TODO: attention_mask
        x, i = self.encoder(x, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)

        # Pooled output: the final hidden state at the end-of-text token, located via argmax
        # since the end token has the highest id in the CLIP vocabulary.
        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
        return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        return self.text_model(*args, **kwargs)
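
For orientation, here is a minimal usage sketch (not part of the commit) of how the new model is meant to be driven, mirroring what the sd1_clip.py changes below do. The config values are placeholders patterned on a ViT-L/14-style text encoder; the real ones are read from sd1_clip_config.json.

import torch
import comfy.ops
from comfy.clip_model import CLIPTextModel

# Placeholder config; sd1_clip.py loads the real values from sd1_clip_config.json.
config = {
    "num_hidden_layers": 12,
    "hidden_size": 768,
    "num_attention_heads": 12,
    "intermediate_size": 3072,
    "hidden_act": "quick_gelu",
}

model = CLIPTextModel(config, dtype=torch.float32, device="cpu", operations=comfy.ops)
tokens = torch.randint(0, 49408, (1, 77))  # a batch of padded token ids
# Returns (last hidden state, selected intermediate layer, pooled output);
# intermediate_output=-2 asks for the penultimate layer, as the SD2/SDXL changes below do.
last_hidden, penultimate, pooled = model(tokens, intermediate_output=-2)
print(last_hidden.shape, penultimate.shape, pooled.shape)  # (1, 77, 768), (1, 77, 768), (1, 768)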
comfy/ldm/modules/attention.py
@@ -112,10 +112,13 @@ def attention_basic(q, k, v, heads, mask=None):
     del q, k

     if exists(mask):
-        mask = rearrange(mask, 'b ... -> b (...)')
-        max_neg_value = -torch.finfo(sim.dtype).max
-        mask = repeat(mask, 'b j -> (b h) () j', h=h)
-        sim.masked_fill_(~mask, max_neg_value)
+        if mask.dtype == torch.bool:
+            mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+        else:
+            sim += mask

     # attention, what we cannot get enough of
     sim = sim.softmax(dim=-1)
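
As a standalone illustration (not part of the diff), the sketch below shows the two mask paths attention_basic now supports: boolean masks keep the masked_fill_ behaviour, while float masks, such as the additive causal mask built in CLIPEncoder above, are simply added to the attention scores (shape handling omitted).

import torch

def apply_mask(sim, mask):
    # bool mask: True = keep, False = fill with a large negative value
    if mask.dtype == torch.bool:
        return sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
    # float mask: additive bias, e.g. -inf above the diagonal
    return sim + mask

scores = torch.zeros(2, 4, 4)                         # dummy (batch*heads, q, k) attention scores
causal = torch.full((4, 4), float("-inf")).triu_(1)   # same construction as CLIPEncoder's causal mask
print(apply_mask(scores, causal).softmax(dim=-1)[0])  # lower-triangular attention weights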
@@ -340,6 +343,18 @@ else:
     if model_management.pytorch_attention_enabled():
         optimized_attention_masked = attention_pytorch

+def optimized_attention_for_device(device, mask=False):
+    if device == torch.device("cpu"): #TODO
+        if model_management.pytorch_attention_enabled():
+            return attention_pytorch
+        else:
+            return attention_basic
+    if mask:
+        return optimized_attention_masked
+
+    return optimized_attention
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
comfy/sd1_clip.py
 import os
-from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
+from transformers import CLIPTokenizer
 import comfy.ops
 import torch
 import traceback
 import zipfile
 from . import model_management
 import contextlib
+import comfy.clip_model
+import json

 def gen_empty_tokens(special_tokens, length):
     start_token = special_tokens.get("start", None)
@@ -65,35 +67,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         "hidden"
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
-                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
-                 model_class=CLIPTextModel, inner_name="text_model"):  # clip-vit-base-patch32
+                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
-        self.num_layers = 12
-        if textmodel_path is not None:
-            self.transformer = model_class.from_pretrained(textmodel_path)
-        else:
-            if textmodel_json_config is None:
-                textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
-            config = config_class.from_json_file(textmodel_json_config)
-            self.num_layers = config.num_hidden_layers
-            with comfy.ops.use_comfy_ops(device, dtype):
-                with modeling_utils.no_init_weights():
-                    self.transformer = model_class(config)
-        self.inner_name = inner_name
-        if dtype is not None:
-            inner_model = getattr(self.transformer, self.inner_name)
-            if hasattr(inner_model, "embeddings"):
-                embeddings_bak = inner_model.embeddings.to(torch.float32)
-                inner_model.embeddings = None
-                self.transformer.to(dtype)
-                inner_model.embeddings = embeddings_bak
-            else:
-                previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True)
-                self.transformer.to(dtype)
-                self.transformer.set_input_embeddings(previous_inputs)
+
+        if textmodel_json_config is None:
+            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
+
+        with open(textmodel_json_config) as f:
+            config = json.load(f)
+
+        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.num_layers = self.transformer.num_layers

         self.max_length = max_length
         if freeze:
@@ -108,7 +94,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer_norm_hidden_state = layer_norm_hidden_state
         if layer == "hidden":
             assert layer_idx is not None
-            assert abs(layer_idx) <= self.num_layers
+            assert abs(layer_idx) < self.num_layers
             self.clip_layer(layer_idx)
         self.layer_default = (self.layer, self.layer_idx)
@@ -119,7 +105,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             param.requires_grad = False

     def clip_layer(self, layer_idx):
-        if abs(layer_idx) >= self.num_layers:
+        if abs(layer_idx) > self.num_layers:
             self.layer = "last"
         else:
             self.layer = "hidden"
@@ -174,7 +160,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)

-        if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
+        if self.transformer.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
             precision_scope = lambda a, dtype: contextlib.nullcontext(a)
@@ -190,20 +176,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                     if tokens[x, y] == max_token:
                         break

-        outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
+        outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
         self.transformer.set_input_embeddings(backup_embeds)

         if self.layer == "last":
-            z = outputs.last_hidden_state
-        elif self.layer == "pooled":
-            z = outputs.pooler_output[:, None, :]
+            z = outputs[0]
         else:
-            z = outputs.hidden_states[self.layer_idx]
-            if self.layer_norm_hidden_state:
-                z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
+            z = outputs[1]

-        if hasattr(outputs, "pooler_output"):
-            pooled_output = outputs.pooler_output.float()
+        if outputs[2] is not None:
+            pooled_output = outputs[2].float()
         else:
             pooled_output = None
comfy/sd2_clip.py
@@ -3,13 +3,13 @@ import torch
 import os

 class SD2ClipHModel(sd1_clip.SDClipModel):
-    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
+    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
-            layer_idx=23
+            layer_idx=-2

         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})

 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
comfy/sdxl_clip.py
@@ -3,13 +3,13 @@ import torch
 import os

 class SDXLClipG(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=-2

         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
                          special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)

     def load_sd(self, sd):
@@ -37,7 +37,7 @@ class SDXLTokenizer:
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
         self.clip_g = SDXLClipG(device=device, dtype=dtype)

     def clip_layer(self, layer_idx):