Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
00c0b2c5
Commit
00c0b2c5
authored
Aug 23, 2023
by
comfyanonymous
Browse files
Initialize text encoder to target dtype.
parent
f081017c
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
29 additions
and
15 deletions
+29
-15
comfy/ops.py
comfy/ops.py
+11
-2
comfy/sd.py
comfy/sd.py
+5
-2
comfy/sd1_clip.py
comfy/sd1_clip.py
+4
-2
comfy/sd2_clip.py
comfy/sd2_clip.py
+2
-2
comfy/sdxl_clip.py
comfy/sdxl_clip.py
+7
-7
No files found.
comfy/ops.py
View file @
00c0b2c5
...
...
@@ -28,9 +28,18 @@ def conv_nd(dims, *args, **kwargs):
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
@
contextmanager
def
use_comfy_ops
():
# Kind of an ugly hack but I can't think of a better way
def
use_comfy_ops
(
device
=
None
,
dtype
=
None
):
# Kind of an ugly hack but I can't think of a better way
old_torch_nn_linear
=
torch
.
nn
.
Linear
torch
.
nn
.
Linear
=
Linear
force_device
=
device
force_dtype
=
dtype
def
linear_with_dtype
(
in_features
:
int
,
out_features
:
int
,
bias
:
bool
=
True
,
device
=
None
,
dtype
=
None
):
if
force_device
is
not
None
:
device
=
force_device
if
force_dtype
is
not
None
:
dtype
=
force_dtype
return
Linear
(
in_features
,
out_features
,
bias
=
bias
,
device
=
device
,
dtype
=
dtype
)
torch
.
nn
.
Linear
=
linear_with_dtype
try
:
yield
finally
:
...
...
comfy/sd.py
View file @
00c0b2c5
...
...
@@ -545,9 +545,12 @@ class CLIP:
load_device
=
model_management
.
text_encoder_device
()
offload_device
=
model_management
.
text_encoder_offload_device
()
params
[
'device'
]
=
load_device
self
.
cond_stage_model
=
clip
(
**
(
params
))
if
model_management
.
should_use_fp16
(
load_device
):
self
.
cond_stage_model
.
half
()
params
[
'dtype'
]
=
torch
.
float16
else
:
params
[
'dtype'
]
=
torch
.
float32
self
.
cond_stage_model
=
clip
(
**
(
params
))
self
.
tokenizer
=
tokenizer
(
embedding_directory
=
embedding_directory
)
self
.
patcher
=
ModelPatcher
(
self
.
cond_stage_model
,
load_device
=
load_device
,
offload_device
=
offload_device
)
...
...
comfy/sd1_clip.py
View file @
00c0b2c5
...
...
@@ -43,7 +43,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden"
]
def
__init__
(
self
,
version
=
"openai/clip-vit-large-patch14"
,
device
=
"cpu"
,
max_length
=
77
,
freeze
=
True
,
layer
=
"last"
,
layer_idx
=
None
,
textmodel_json_config
=
None
,
textmodel_path
=
None
):
# clip-vit-base-patch32
freeze
=
True
,
layer
=
"last"
,
layer_idx
=
None
,
textmodel_json_config
=
None
,
textmodel_path
=
None
,
dtype
=
None
):
# clip-vit-base-patch32
super
().
__init__
()
assert
layer
in
self
.
LAYERS
self
.
num_layers
=
12
...
...
@@ -54,10 +54,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
textmodel_json_config
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"sd1_clip_config.json"
)
config
=
CLIPTextConfig
.
from_json_file
(
textmodel_json_config
)
self
.
num_layers
=
config
.
num_hidden_layers
with
comfy
.
ops
.
use_comfy_ops
():
with
comfy
.
ops
.
use_comfy_ops
(
device
,
dtype
):
with
modeling_utils
.
no_init_weights
():
self
.
transformer
=
CLIPTextModel
(
config
)
if
dtype
is
not
None
:
self
.
transformer
.
to
(
dtype
)
self
.
max_length
=
max_length
if
freeze
:
self
.
freeze
()
...
...
comfy/sd2_clip.py
View file @
00c0b2c5
...
...
@@ -3,13 +3,13 @@ import torch
import
os
class
SD2ClipModel
(
sd1_clip
.
SD1ClipModel
):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True,
             layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
    """SD 2.x text encoder: an SD1ClipModel configured for the ViT-H CLIP
    variant, defaulting to the penultimate hidden layer.

    Args mirror SD1ClipModel; `dtype` is forwarded so the transformer is
    initialized directly in the target dtype.
    """
    # "penultimate" is translated into an explicit hidden-layer index
    # before delegating to the base class.
    if layer == "penultimate":
        layer = "hidden"
        layer_idx = 23
    config_dir = os.path.dirname(os.path.realpath(__file__))
    textmodel_json_config = os.path.join(config_dir, "sd2_clip_config.json")
    super().__init__(device=device, freeze=freeze, layer=layer,
                     layer_idx=layer_idx,
                     textmodel_json_config=textmodel_json_config,
                     textmodel_path=textmodel_path, dtype=dtype)
    # BOS token, EOS token, then zero-padding out to 77 tokens.
    self.empty_tokens = [[49406, 49407] + [0] * 75]
def
clip_layer
(
self
,
layer_idx
):
...
...
comfy/sdxl_clip.py
View file @
00c0b2c5
...
...
@@ -3,13 +3,13 @@ import torch
import
os
class
SDXLClipG
(
sd1_clip
.
SD1ClipModel
):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate",
             layer_idx=None, textmodel_path=None, dtype=None):
    """SDXL "big G" text encoder built on SD1ClipModel.

    Uses the bigG CLIP config and, by default, the penultimate hidden
    layer. `dtype` is forwarded so the transformer is created in the
    target dtype. Also registers the text projection matrix and logit
    scale parameters expected by SDXL checkpoints.
    """
    # Resolve the "penultimate" alias to an explicit hidden-layer index.
    if layer == "penultimate":
        layer = "hidden"
        layer_idx = -2
    config_dir = os.path.dirname(os.path.realpath(__file__))
    textmodel_json_config = os.path.join(config_dir, "clip_config_bigg.json")
    super().__init__(device=device, freeze=freeze, layer=layer,
                     layer_idx=layer_idx,
                     textmodel_json_config=textmodel_json_config,
                     textmodel_path=textmodel_path, dtype=dtype)
    # BOS token, EOS token, then zero-padding out to 77 tokens.
    self.empty_tokens = [[49406, 49407] + [0] * 75]
    # Uninitialized here on purpose: weights come from the checkpoint.
    self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
    self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
...
...
@@ -42,11 +42,11 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
return
self
.
clip_g
.
untokenize
(
token_weight_pair
)
class
SDXLClipModel
(
torch
.
nn
.
Module
):
def __init__(self, device="cpu", dtype=None):
    """Dual text encoder for SDXL: CLIP-L (hidden layer 11, no final
    layer norm on the hidden state) alongside the bigG encoder.

    Both sub-encoders receive `device` and `dtype` so they are built
    directly on the target device in the target dtype.
    """
    super().__init__()
    clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11,
                                   device=device, dtype=dtype)
    # SDXL consumes the raw hidden state from CLIP-L.
    clip_l.layer_norm_hidden_state = False
    self.clip_l = clip_l
    self.clip_g = SDXLClipG(device=device, dtype=dtype)
def
clip_layer
(
self
,
layer_idx
):
self
.
clip_l
.
clip_layer
(
layer_idx
)
...
...
@@ -70,9 +70,9 @@ class SDXLClipModel(torch.nn.Module):
return
self
.
clip_l
.
load_sd
(
sd
)
class
SDXLRefinerClipModel
(
torch
.
nn
.
Module
):
def __init__(self, device="cpu", dtype=None):
    """Text encoder for the SDXL refiner, which uses only the bigG
    CLIP model (no CLIP-L branch).

    `device` and `dtype` are forwarded so the encoder is initialized
    directly in the target dtype.
    """
    super().__init__()
    self.clip_g = SDXLClipG(device=device, dtype=dtype)
def
clip_layer
(
self
,
layer_idx
):
self
.
clip_g
.
clip_layer
(
layer_idx
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment