chenpangpang / ComfyUI · Commits · 2c038cce

Commit 2c038cce, authored Jul 31, 2024 by comfyanonymous
parent b85216a3

Lower CLIP memory usage by a bit.

Showing 4 changed files with 28 additions and 25 deletions (+28 -25):

comfy/clip_model.py          +12 -11
comfy/sd1_clip.py             +4  -3
comfy/text_encoders/bert.py  +11 -10
comfy/text_encoders/t5.py     +1  -1
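The commit message is literal: the only change is that the text-encoder embedding tables are allowed to live in the checkpoint's dtype instead of being forced to float32, with casting deferred to lookup time. A rough sense of the scale, using the vocab_size=49408 and num_positions=77 defaults from the diff below and assuming embed_dim=768 (the usual CLIP-L width, not stated anywhere in this commit):

    # Back-of-the-envelope memory estimate (illustrative; embed_dim=768 is assumed).
    vocab_size, num_positions, embed_dim = 49408, 77, 768
    params = (vocab_size + num_positions) * embed_dim

    fp32_mib = params * 4 / 2**20   # ~145 MiB when the tables are forced to float32
    fp16_mib = params * 2 / 2**20   # ~72.5 MiB when kept in an fp16 checkpoint dtype
    print(f"fp32: {fp32_mib:.1f} MiB, fp16: {fp16_mib:.1f} MiB")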
comfy/clip_model.py

 import torch
 from comfy.ldm.modules.attention import optimized_attention_for_device
+import comfy.ops


 class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device, operations):
...
@@ -71,13 +72,13 @@ class CLIPEncoder(torch.nn.Module):
         return x, intermediate

 class CLIPEmbeddings(torch.nn.Module):
-    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
+    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
         super().__init__()
-        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
-        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+        self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+        self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

-    def forward(self, input_tokens):
-        return self.token_embedding(input_tokens) + self.position_embedding.weight
+    def forward(self, input_tokens, dtype=torch.float32):
+        return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)

 class CLIPTextModel_(torch.nn.Module):
...
@@ -90,12 +91,12 @@ class CLIPTextModel_(torch.nn.Module):
         self.eos_token_id = config_dict["eos_token_id"]

         super().__init__()
-        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+        self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)

-    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
-        x = self.embeddings(input_tokens)
+    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+        x = self.embeddings(input_tokens, dtype=dtype)
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
...
@@ -154,11 +155,11 @@ class CLIPVisionEmbeddings(torch.nn.Module):
         num_patches = (image_size // patch_size) ** 2
         num_positions = num_patches + 1
-        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+        self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

     def forward(self, pixel_values):
         embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
-        return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
+        return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)

 class CLIPVision(torch.nn.Module):
...
@@ -170,7 +171,7 @@ class CLIPVision(torch.nn.Module):
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]

-        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
+        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
         self.pre_layrnorm = operations.LayerNorm(embed_dim)
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.post_layernorm = operations.LayerNorm(embed_dim)
...
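The hunks above carry the whole mechanism: embedding tables are built through the cast-aware operations.Embedding and their weights are only converted to the compute dtype when they are read (out_dtype= on the lookup, comfy.ops.cast_to / cast_to_input on raw weight tensors). A minimal sketch of a cast-on-read embedding in that spirit follows; it is an illustration under stated assumptions, not ComfyUI's actual comfy.ops.manual_cast implementation, and the class and helper names are hypothetical.

    import torch

    def cast_to(weight, dtype=None, device=None):
        # Hypothetical stand-in for a cast helper: return a copy in the requested
        # dtype/device without changing how the parameter itself is stored.
        return weight.to(dtype=dtype, device=device)

    class CastOnReadEmbedding(torch.nn.Embedding):
        # Weights stay in their storage dtype (e.g. fp16 straight from the
        # checkpoint); only the lookup result is produced in the compute dtype.
        def forward(self, input_tokens, out_dtype=None):
            w = cast_to(self.weight, dtype=out_dtype, device=input_tokens.device)
            return torch.nn.functional.embedding(input_tokens, w)

    # Usage: fp16 storage, fp32 activations.
    emb = CastOnReadEmbedding(49408, 768, dtype=torch.float16)
    tokens = torch.randint(0, 49408, (1, 77))
    out = emb(tokens, out_dtype=torch.float32)
    print(emb.weight.dtype, out.dtype)  # torch.float16 torch.float32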
comfy/sd1_clip.py

...
@@ -94,7 +94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         with open(textmodel_json_config) as f:
             config = json.load(f)

-        self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
+        self.operations = comfy.ops.manual_cast
+        self.transformer = model_class(config, dtype, device, self.operations)
         self.num_layers = self.transformer.num_layers

         self.max_length = max_length
...
@@ -161,7 +162,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                 n = token_dict_size
                 if len(embedding_weights) > 0:
-                    new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
+                    new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
                     new_embedding.weight[:token_dict_size] = current_embeds.weight
                     for x in embedding_weights:
                         new_embedding.weight[n] = x
...
@@ -194,7 +195,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if self.enable_attention_masks:
             attention_mask_model = attention_mask

-        outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+        outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)

         self.transformer.set_input_embeddings(backup_embeds)

         if self.layer == "last":
...
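The new self.operations attribute matters for the middle hunk: when extra textual-inversion tokens force the token table to be rebuilt at a larger size, the replacement table is now created through the same cast-aware Embedding as the rest of the model rather than a plain torch.nn.Embedding. A self-contained sketch of that grow-and-copy step is below; it uses plain torch.nn.Embedding only so it runs on its own, and does not reproduce the operations namespace or the surrounding SDClipModel state.

    import torch

    # Existing table plus two extra embedding vectors to append (shapes are examples).
    current_embeds = torch.nn.Embedding(49408, 768, dtype=torch.float16)
    embedding_weights = [torch.randn(768, dtype=torch.float16) for _ in range(2)]

    token_dict_size = current_embeds.num_embeddings
    next_new_token = token_dict_size + len(embedding_weights) - 1

    # The commit builds this through self.operations.Embedding; torch.nn.Embedding
    # stands in here to keep the sketch self-contained.
    new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1],
                                       device=current_embeds.weight.device,
                                       dtype=current_embeds.weight.dtype)
    with torch.no_grad():
        new_embedding.weight[:token_dict_size] = current_embeds.weight
        n = token_dict_size
        for x in embedding_weights:
            new_embedding.weight[n] = x
            n += 1

The last hunk keeps the call-site contract unchanged: the transformer is still asked for dtype=torch.float32 activations, so only the stored weights shrink, not the precision of the hidden states handed downstream.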
comfy/text_encoders/bert.py

 import torch
 from comfy.ldm.modules.attention import optimized_attention_for_device
+import comfy.ops


 class BertAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device, operations):
...
@@ -86,19 +87,19 @@ class BertEncoder(torch.nn.Module):

 class BertEmbeddings(torch.nn.Module):
     def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_token_id, embed_dim, layer_norm_eps, dtype, device, operations):
         super().__init__()
-        self.word_embeddings = torch.nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device)
-        self.position_embeddings = torch.nn.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device)
-        self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device)
+        self.word_embeddings = operations.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device)
+        self.position_embeddings = operations.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device)
+        self.token_type_embeddings = operations.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device)

         self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)

-    def forward(self, input_tokens, token_type_ids=None):
-        x = self.word_embeddings(input_tokens)
-        x += self.position_embeddings.weight[:x.shape[1]]
+    def forward(self, input_tokens, token_type_ids=None, dtype=None):
+        x = self.word_embeddings(input_tokens, out_dtype=dtype)
+        x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
         if token_type_ids is not None:
-            x += self.token_type_embeddings(token_type_ids)
+            x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
         else:
-            x += self.token_type_embeddings.weight[0]
+            x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x)
         x = self.LayerNorm(x)
         return x
...
@@ -112,8 +113,8 @@ class BertModel_(torch.nn.Module):
         self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
         self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)

-    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
-        x = self.embeddings(input_tokens)
+    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+        x = self.embeddings(input_tokens, dtype=dtype)
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
...
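comfy.ops.cast_to_input(tensor, x) appears throughout these hunks with the same call shape: the first argument is a stored weight (or a slice of one), the second is the activation tensor it is about to be combined with. A hedged sketch of a helper with that shape follows, assuming it simply matches the activation's dtype and device; the real comfy.ops implementation may do more, and this is not it.

    import torch

    def cast_to_input(weight, input_tensor):
        # Assumed behaviour for illustration only: align dtype and device with
        # the tensor the weight is combined with.
        return weight.to(dtype=input_tensor.dtype, device=input_tensor.device)

    # In the spirit of the BertEmbeddings hunk: fp16 position rows added to fp32
    # activations without permanently upcasting the stored table.
    position_weight = torch.randn(512, 768, dtype=torch.float16)
    x = torch.zeros(1, 77, 768, dtype=torch.float32)
    x += cast_to_input(position_weight[:x.shape[1]], x)
    print(position_weight.dtype, x.dtype)  # torch.float16 torch.float32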
comfy/text_encoders/t5.py

...
@@ -200,7 +200,7 @@ class T5Stack(torch.nn.Module):
         self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
         # self.dropout = nn.Dropout(config.dropout_rate)

-    def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
+    def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
...