chenpangpang / ComfyUI / Commits

Commit 174eba8e, authored Dec 09, 2023 by comfyanonymous
Parent: 97015b6b

    Use own clip vision model implementation.

Showing 2 changed files with 81 additions and 22 deletions.
comfy/clip_model.py    +64  -6
comfy/clip_vision.py   +17  -16
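
Before this commit, comfy/clip_vision.py wrapped transformers' CLIPVisionModelWithProjection; after it, the vision tower is ComfyUI's own comfy.clip_model.CLIPVisionModelProjection built on comfy.ops. A minimal sketch of constructing the new class directly (not part of the commit): it assumes the ComfyUI tree at this commit is importable, and the config values are the usual CLIP ViT-L/14 ones, listed purely for illustration.

import torch
import comfy.ops
import comfy.clip_model

# Standard CLIP ViT-L/14 vision settings (illustrative only).
config = {
    "num_hidden_layers": 24,
    "hidden_size": 1024,
    "num_attention_heads": 16,
    "intermediate_size": 4096,
    "hidden_act": "quick_gelu",
    "num_channels": 3,
    "patch_size": 14,
    "image_size": 224,
    "projection_dim": 768,
}

# Mirrors the call clip_vision.py now makes: (config_dict, dtype, device, operations module),
# with device=None standing in for the offload device.
model = comfy.clip_model.CLIPVisionModelProjection(config, torch.float32, None, comfy.ops)

# Weights are uninitialized here, so the numbers are meaningless; only shapes are checked.
pixel_values = torch.zeros(1, 3, 224, 224)
last_hidden, penultimate, image_embeds = model(pixel_values, intermediate_output=-2)
print(image_embeds.shape)  # torch.Size([1, 768])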
comfy/clip_model.py @ 174eba8e
@@ -57,12 +57,7 @@ class CLIPEncoder(torch.nn.Module):
         self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
 
     def forward(self, x, mask=None, intermediate_output=None):
-        optimized_attention = optimized_attention_for_device(x.device, mask=True)
-        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
-        if mask is not None:
-            mask += causal_mask
-        else:
-            mask = causal_mask
+        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None)
 
         if intermediate_output is not None:
             if intermediate_output < 0:
@@ -105,6 +100,12 @@ class CLIPTextModel_(torch.nn.Module):
             mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
             mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
 
+        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+        if mask is not None:
+            mask += causal_mask
+        else:
+            mask = causal_mask
+
         x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
         x = self.final_layer_norm(x)
         if i is not None and final_layer_norm_intermediate:
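
The two hunks above relocate the causal mask from CLIPEncoder.forward into CLIPTextModel_.forward, and optimized_attention_for_device is now asked for a masked kernel only when a mask actually exists. That frees the shared CLIPEncoder to run maskless for the vision tower added below, while the text model keeps its causal behaviour. A quick standalone look at the mask being moved (plain PyTorch, not part of the commit):

import torch

seq_len = 4  # illustrative sequence length
causal_mask = torch.empty(seq_len, seq_len).fill_(float("-inf")).triu_(1)
print(causal_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
# Added to the attention scores, it lets each token attend only to itself and earlier tokens.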
@@ -128,3 +129,60 @@ class CLIPTextModel(torch.nn.Module):
 
     def forward(self, *args, **kwargs):
         return self.text_model(*args, **kwargs)
+
+
+class CLIPVisionEmbeddings(torch.nn.Module):
+    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+
+        self.patch_embedding = operations.Conv2d(
+            in_channels=num_channels,
+            out_channels=embed_dim,
+            kernel_size=patch_size,
+            stride=patch_size,
+            bias=False,
+            dtype=dtype,
+            device=device
+        )
+
+        num_patches = (image_size // patch_size) ** 2
+        num_positions = num_patches + 1
+        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+    def forward(self, pixel_values):
+        embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
+        return torch.cat([self.class_embedding.expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight
+
+
+class CLIPVision(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        num_layers = config_dict["num_hidden_layers"]
+        embed_dim = config_dict["hidden_size"]
+        heads = config_dict["num_attention_heads"]
+        intermediate_size = config_dict["intermediate_size"]
+        intermediate_activation = config_dict["hidden_act"]
+
+        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
+        self.pre_layrnorm = operations.LayerNorm(embed_dim)
+        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+        self.post_layernorm = operations.LayerNorm(embed_dim)
+
+    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
+        x = self.embeddings(pixel_values)
+        x = self.pre_layrnorm(x)
+        #TODO: attention_mask?
+        x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
+        pooled_output = self.post_layernorm(x[:, 0, :])
+        return x, i, pooled_output
+
+
+class CLIPVisionModelProjection(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        self.vision_model = CLIPVision(config_dict, dtype, device, operations)
+        self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+
+    def forward(self, *args, **kwargs):
+        x = self.vision_model(*args, **kwargs)
+        out = self.visual_projection(x[2])
+        return (x[0], x[1], out)
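
The patch and position embedding arithmetic in the new CLIPVisionEmbeddings above can be sanity-checked without ComfyUI by swapping plain torch.nn layers in for the operations wrapper. The sizes below are the usual ViT-L/14 ones (224x224 input, 14x14 patches, 1024-dim hidden) and are only assumptions for the sketch:

import torch

embed_dim, num_channels, patch_size, image_size = 1024, 3, 14, 224

class_embedding = torch.nn.Parameter(torch.zeros(embed_dim))
patch_embedding = torch.nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)

num_patches = (image_size // patch_size) ** 2   # 16 * 16 = 256 patches
num_positions = num_patches + 1                 # plus one class token
position_embedding = torch.nn.Embedding(num_positions, embed_dim)

pixel_values = torch.zeros(1, num_channels, image_size, image_size)
embeds = patch_embedding(pixel_values).flatten(2).transpose(1, 2)   # (1, 256, 1024)
tokens = torch.cat([class_embedding.expand(1, 1, -1), embeds], dim=1) + position_embedding.weight
print(tokens.shape)  # torch.Size([1, 257, 1024])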
comfy/clip_vision.py @ 174eba8e
-from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
 from .utils import load_torch_file, transformers_convert, common_upscale
 import os
 import torch
 import contextlib
+import json
 
 import comfy.ops
 import comfy.model_patcher
 import comfy.model_management
 import comfy.utils
+import comfy.clip_model
+
+class Output:
+    def __getitem__(self, key):
+        return getattr(self, key)
+    def __setitem__(self, key, item):
+        setattr(self, key, item)
 
 def clip_preprocess(image, size=224):
     mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)
@@ -22,17 +29,16 @@ def clip_preprocess(image, size=224):
 
 class ClipVisionModel():
     def __init__(self, json_config):
-        config = CLIPVisionConfig.from_json_file(json_config)
+        with open(json_config) as f:
+            config = json.load(f)
+
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = torch.float32
         if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
             self.dtype = torch.float16
 
-        with comfy.ops.use_comfy_ops(offload_device, self.dtype):
-            with modeling_utils.no_init_weights():
-                self.model = CLIPVisionModelWithProjection(config)
-        self.model.to(self.dtype)
+        self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops)
         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
 
     def load_sd(self, sd):
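
A minimal sketch of the new loading path shown in this hunk, assuming the ComfyUI tree at this commit is importable; the ViT-L config file name is an assumption about what ships next to clip_vision.py, and any JSON with the keys read by CLIPVision would do:

import os
import comfy.clip_vision

# Assumed file name; substitute any CLIP vision config JSON that provides
# num_hidden_layers, hidden_size, num_attention_heads, intermediate_size,
# hidden_act, num_channels, patch_size, image_size and projection_dim.
json_config = os.path.join(os.path.dirname(comfy.clip_vision.__file__), "clip_vision_config_vitl.json")

clip = comfy.clip_vision.ClipVisionModel(json_config)
print(type(clip.model))  # <class 'comfy.clip_model.CLIPVisionModelProjection'>
print(clip.dtype)        # torch.float32, or torch.float16 where fp16 is preferred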
@@ -48,17 +54,12 @@ class ClipVisionModel():
             precision_scope = lambda a, b: contextlib.nullcontext(a)
 
         with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
-            outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
+            out = self.model(pixel_values=pixel_values, intermediate_output=-2)
 
-        for k in outputs:
-            t = outputs[k]
-            if t is not None:
-                if k == 'hidden_states':
-                    outputs["penultimate_hidden_states"] = t[-2].to(comfy.model_management.intermediate_device())
-                    outputs["hidden_states"] = None
-                else:
-                    outputs[k] = t.to(comfy.model_management.intermediate_device())
-
+        outputs = Output()
+        outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
+        outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
+        outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
         return outputs
 
 def convert_to_transformers(sd, prefix):
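
The repacking in the last hunk can be exercised without ComfyUI at all. Below, the tuple returned by CLIPVisionModelProjection is stood in for by random tensors with ViT-L/14-like shapes (an assumption), and the Output helper from the top of the file is re-declared so the snippet runs on its own:

import torch

class Output:  # copied from comfy/clip_vision.py above
    def __getitem__(self, key):
        return getattr(self, key)
    def __setitem__(self, key, item):
        setattr(self, key, item)

# Stand-ins for the (x[0], x[1], out) tuple returned by CLIPVisionModelProjection.
out = (torch.rand(1, 257, 1024),   # encoder's last hidden state
       torch.rand(1, 257, 1024),   # hidden state picked by intermediate_output=-2 (penultimate layer)
       torch.rand(1, 768))         # pooled output passed through visual_projection

outputs = Output()
outputs["last_hidden_state"] = out[0]
outputs["image_embeds"] = out[2]
outputs["penultimate_hidden_states"] = out[1]

# Item and attribute access are interchangeable, mimicking the transformers
# output object the old code returned.
print(outputs.image_embeds.shape)                   # torch.Size([1, 768])
print(outputs["penultimate_hidden_states"].shape)   # torch.Size([1, 257, 1024])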