chenpangpang / ComfyUI / Commits

Commit 0e49211a, authored Jun 11, 2024 by comfyanonymous
Parent: 5889b7ca

    Load the SD3 T5xxl model in the same dtype stored in the checkpoint.
Showing 6 changed files with 49 additions and 6 deletions (+49 -6)
comfy/model_management.py    +17  -0
comfy/sd.py                   +7  -1
comfy/sd1_clip.py             +4  -0
comfy/sd3_clip.py            +15  -3
comfy/sdxl_clip.py            +1  -0
comfy/supported_models.py     +5  -2
comfy/model_management.py

@@ -639,6 +639,23 @@ def supports_dtype(device, dtype): #TODO
         return True
     return False
 
+def supports_cast(device, dtype): #TODO
+    if dtype == torch.float32:
+        return True
+    if dtype == torch.float16:
+        return True
+    if is_device_mps(device):
+        return False
+    if directml_enabled: #TODO: test this
+        return False
+    if dtype == torch.bfloat16:
+        return True
+    if dtype == torch.float8_e4m3fn:
+        return True
+    if dtype == torch.float8_e5m2:
+        return True
+    return False
+
 def device_supports_non_blocking(device):
     if is_device_mps(device):
         return False #pytorch bug? mps doesn't support non blocking
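The new `supports_cast` gate answers a narrower question than the existing `supports_dtype`: not whether the device can run computation in a dtype, but whether weights stored in that dtype can be cast on it. Below is a minimal standalone sketch of the same decision table; `supports_cast_demo` is a hypothetical name, the real function also consults the module-level `directml_enabled` flag, and the float8 dtypes assume a PyTorch build of 2.1 or newer.

import torch

def supports_cast_demo(device, dtype):
    # fp32/fp16 casts are treated as universally supported.
    if dtype in (torch.float32, torch.float16):
        return True
    # MPS (and DirectML, in the real function) cannot cast the rest.
    if device.type == "mps":
        return False
    # bf16 and both fp8 formats are castable everywhere else.
    return dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2)

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for dt in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
    print(dt, "->", supports_cast_demo(dev, dt))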
comfy/sd.py

@@ -98,13 +98,19 @@ class CLIP:
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = offload_device
-        params['dtype'] = model_management.text_encoder_dtype(load_device)
+        dtype = model_management.text_encoder_dtype(load_device)
+        params['dtype'] = dtype
 
         self.cond_stage_model = clip(**(params))
 
+        for dt in self.cond_stage_model.dtypes:
+            if not model_management.supports_cast(load_device, dt):
+                load_device = offload_device
+
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
+        logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
 
     def clone(self):
         n = CLIP(no_init=True)
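The effect of the new loop: if any text-encoder submodule reports a dtype the intended load device cannot cast (fp8 on MPS, for example), the whole text encoder stays on the offload device. A self-contained sketch of just that fallback, with a stand-in for `model_management.supports_cast` and assumed device names:

import torch

def supports_cast(device, dtype):  # stand-in for comfy.model_management.supports_cast
    if dtype in (torch.float32, torch.float16):
        return True
    return device != "mps" and dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2)

load_device, offload_device = "mps", "cpu"     # assumed devices
dtypes = {torch.float16, torch.float8_e4m3fn}  # what cond_stage_model.dtypes might report

for dt in dtypes:                              # same loop as in CLIP.__init__ above
    if not supports_cast(load_device, dt):
        load_device = offload_device           # fp8 on MPS -> fall back to CPU

print("text encoder load device:", load_device)  # prints: cpu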
comfy/sd1_clip.py

@@ -511,6 +511,10 @@ class SD1ClipModel(torch.nn.Module):
         self.clip = "clip_{}".format(self.clip_name)
         setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
 
+        self.dtypes = set()
+        if dtype is not None:
+            self.dtypes.add(dtype)
+
     def set_clip_options(self, options):
         getattr(self, self.clip).set_clip_options(options)
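This is the bookkeeping half of the commit: each text-encoder class records the dtypes of its submodules so `CLIP.__init__` (comfy/sd.py above) can query them. A `None` dtype is deliberately not recorded, so an unknown dtype never triggers the device fallback. A tiny sketch of the pattern with a hypothetical class name:

import torch

class DtypeTracking:
    # Mirrors the bookkeeping added across the text-encoder classes.
    def __init__(self, dtype=None):
        self.dtypes = set()
        if dtype is not None:
            self.dtypes.add(dtype)

print(DtypeTracking(torch.float16).dtypes)  # {torch.float16}
print(DtypeTracking().dtypes)               # set(): dtype unknown, nothing recorded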
comfy/sd3_clip.py

@@ -44,24 +44,36 @@ class SD3Tokenizer:
         return self.clip_g.untokenize(token_weight_pair)
 
 class SD3ClipModel(torch.nn.Module):
-    def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None):
+    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
         super().__init__()
+        self.dtypes = set()
         if clip_l:
             self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
+            self.dtypes.add(dtype)
         else:
             self.clip_l = None
 
         if clip_g:
             self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
+            self.dtypes.add(dtype)
         else:
             self.clip_g = None
 
         if t5:
-            self.t5xxl = T5XXLModel(device=device, dtype=dtype)
+            if dtype_t5 is None:
+                dtype_t5 = dtype
+            elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
+                dtype_t5 = dtype
+
+            if not comfy.model_management.supports_cast(device, dtype_t5):
+                dtype_t5 = dtype
+
+            self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
+            self.dtypes.add(dtype_t5)
         else:
             self.t5xxl = None
-        logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5))
+        logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
 
     def set_clip_options(self, options):
         if self.clip_l is not None:
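The three `dtype_t5` rules keep the checkpoint's stored T5 dtype only when it is no larger than the requested text-encoder dtype and the device can actually cast it; otherwise they fall back to `dtype`. A small sketch under those assumptions (`pick_t5_dtype` and the inline size table are hypothetical; the real code calls `comfy.model_management.dtype_size` and `supports_cast`):

import torch

SIZE = {torch.float32: 4, torch.float16: 2, torch.bfloat16: 2,
        torch.float8_e4m3fn: 1, torch.float8_e5m2: 1}  # bytes per element

def pick_t5_dtype(dtype_t5, dtype, castable):
    if dtype_t5 is None:              # nothing stored -> use the regular dtype
        return dtype
    if SIZE[dtype_t5] > SIZE[dtype]:  # never upcast past the requested dtype
        return dtype
    if not castable(dtype_t5):        # device can't cast it -> fall back
        return dtype
    return dtype_t5

always = lambda dt: True
print(pick_t5_dtype(torch.float8_e4m3fn, torch.float16, always))  # fp8 checkpoint kept
print(pick_t5_dtype(torch.float32, torch.float16, always))        # fp32 checkpoint -> fp16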
comfy/sdxl_clip.py

@@ -39,6 +39,7 @@ class SDXLClipModel(torch.nn.Module):
         super().__init__()
         self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
         self.clip_g = SDXLClipG(device=device, dtype=dtype)
+        self.dtypes = set([dtype])
 
     def set_clip_options(self, options):
         self.clip_l.set_clip_options(options)
comfy/supported_models.py

@@ -511,17 +511,20 @@ class SD3(supported_models_base.BASE):
         clip_l = False
         clip_g = False
         t5 = False
+        dtype_t5 = None
         pref = self.text_encoder_key_prefix[0]
         if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
             clip_l = True
         if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
             clip_g = True
-        if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict:
+        t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
+        if t5_key in state_dict:
             t5 = True
+            dtype_t5 = state_dict[t5_key].dtype
 
         class SD3ClipModel(sd3_clip.SD3ClipModel):
             def __init__(self, device="cpu", dtype=None):
-                super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype)
+                super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
 
         return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)
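This hunk is where the commit title comes from: the stored dtype of one T5 weight in the checkpoint decides `dtype_t5` for the whole model. A hypothetical miniature state dict shows the lookup; the `text_encoders.` prefix is assumed to be what SD3's `text_encoder_key_prefix[0]` resolves to, and creating float8 tensors needs PyTorch 2.1 or newer:

import torch

pref = "text_encoders."  # assumed SD3 text-encoder key prefix
state_dict = {
    pref + "t5xxl.transformer.encoder.final_layer_norm.weight":
        torch.empty(8, dtype=torch.float8_e4m3fn),  # fp8 checkpoint weight
}

t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
if t5_key in state_dict:
    dtype_t5 = state_dict[t5_key].dtype  # torch.float8_e4m3fn
    print("T5xxl will be loaded in:", dtype_t5)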