chenpangpang/ComfyUI · Commit ba07cb74, authored Dec 11, 2023 by comfyanonymous
Use faster manual cast for fp8 in unet.
Parent: ab93abd4
Showing 5 changed files with 48 additions and 12 deletions:
- comfy/model_base.py (+10 -9)
- comfy/model_management.py (+15 -1)
- comfy/ops.py (+9 -0)
- comfy/sd.py (+10 -2)
- comfy/supported_models_base.py (+4 -0)
comfy/model_base.py

```diff
@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import comfy.model_management
 import comfy.conds
+import comfy.ops
 from enum import Enum
 import contextlib
 from . import utils
@@ -41,9 +42,14 @@ class BaseModel(torch.nn.Module):
         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
+        self.manual_cast_dtype = model_config.manual_cast_dtype

         if not unet_config.get("disable_unet_model_creation", False):
-            self.diffusion_model = UNetModel(**unet_config, device=device)
+            if self.manual_cast_dtype is not None:
+                operations = comfy.ops.manual_cast
+            else:
+                operations = comfy.ops
+            self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
@@ -63,11 +69,8 @@ class BaseModel(torch.nn.Module):
         context = c_crossattn
         dtype = self.get_dtype()
-        if comfy.model_management.supports_dtype(xc.device, dtype):
-            precision_scope = lambda a: contextlib.nullcontext(a)
-        else:
-            precision_scope = torch.autocast
-            dtype = torch.float32
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype

         xc = xc.to(dtype)
         t = self.model_sampling.timestep(t).float()
@@ -79,9 +82,7 @@ class BaseModel(torch.nn.Module):
                 extra = extra.to(dtype)
             extra_conds[o] = extra
-        with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
-            model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
+        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
         return self.model_sampling.calculate_denoised(sigma, model_output, x)

     def get_dtype(self):
```
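The core of the change is visible here: instead of wrapping the forward pass in torch.autocast when the weight dtype is not natively supported, BaseModel now builds the UNet from comfy.ops.manual_cast layer classes whenever the model config carries a manual_cast_dtype, and simply casts the inputs to that dtype. A minimal sketch of the "operations namespace" pattern (the class and layer names below are illustrative, not ComfyUI's real UNet):

```python
import torch

# Sketch: the model is built against an "operations" namespace, so swapping
# comfy.ops for comfy.ops.manual_cast changes which Linear/Conv classes get
# instantiated without touching the UNet code itself.
class default_ops:
    Linear = torch.nn.Linear

class manual_cast_ops:
    class Linear(torch.nn.Linear):
        # Weights keep their storage dtype (e.g. fp8) and are cast to the
        # input's dtype per layer at forward time, instead of relying on
        # torch.autocast around the whole forward pass.
        def forward(self, input):
            weight = self.weight.to(dtype=input.dtype, device=input.device)
            bias = None if self.bias is None else self.bias.to(dtype=input.dtype, device=input.device)
            return torch.nn.functional.linear(input, weight, bias)

class TinyDenoiser(torch.nn.Module):
    def __init__(self, operations=default_ops):
        super().__init__()
        self.proj = operations.Linear(4, 4)

manual_cast_dtype = torch.float16  # what model_config.manual_cast_dtype would hold
operations = manual_cast_ops if manual_cast_dtype is not None else default_ops
model = TinyDenoiser(operations=operations)
x = torch.randn(1, 4, dtype=torch.float16)
# out = model.proj(x)  # each layer casts its own weights to x.dtype
```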
comfy/model_management.py

```diff
@@ -474,6 +474,20 @@ def unet_dtype(device=None, model_params=0):
             return torch.float16
     return torch.float32

+# None means no manual cast
+def unet_manual_cast(weight_dtype, inference_device):
+    if weight_dtype == torch.float32:
+        return None
+
+    fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
+    if fp16_supported and weight_dtype == torch.float16:
+        return None
+
+    if fp16_supported:
+        return torch.float16
+    else:
+        return torch.float32
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
@@ -538,7 +552,7 @@ def get_autocast_device(dev):
 def supports_dtype(device, dtype): #TODO
     if dtype == torch.float32:
         return True
-    if torch.device("cpu") == device:
+    if is_device_cpu(device):
         return False
     if dtype == torch.float16:
         return True
```
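unet_manual_cast is the new decision function: given the dtype the weights are stored in and the device inference will run on, it returns the dtype compute should happen in, or None if the weights can run natively. A hedged usage sketch (torch.float8_e4m3fn assumes a PyTorch build with float8 dtypes, available since 2.1):

```python
import torch
import comfy.model_management as model_management

device = model_management.get_torch_device()

# Decision table implemented by unet_manual_cast:
#   fp32 weights                   -> None  (runs natively everywhere)
#   fp16 weights on fp16 hardware  -> None  (runs natively)
#   fp8 weights on fp16 hardware   -> torch.float16  (manual cast, the fast path)
#   fp8/fp16 weights, no fp16 HW   -> torch.float32  (manual cast up to fp32)
compute_dtype = model_management.unet_manual_cast(torch.float8_e4m3fn, device)
```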
comfy/ops.py

```diff
@@ -62,6 +62,15 @@ class manual_cast:
             weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

+    @classmethod
+    def conv_nd(s, dims, *args, **kwargs):
+        if dims == 2:
+            return s.Conv2d(*args, **kwargs)
+        elif dims == 3:
+            return s.Conv3d(*args, **kwargs)
+        else:
+            raise ValueError(f"unsupported dimensions: {dims}")
+
 @contextmanager
 def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear
```
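conv_nd is what lets the manual_cast class act as a drop-in replacement for the comfy.ops module: the openaimodel UNet builds its convolutions through the operations object's conv_nd factory, so the class needs the same method. A toy sketch of the dispatch (ops_sketch is illustrative):

```python
import torch

class ops_sketch:
    Conv2d = torch.nn.Conv2d
    Conv3d = torch.nn.Conv3d

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        # Dispatch on dimensionality, mirroring the method added above;
        # s.Conv2d resolves through the (sub)class, so a manual-cast variant
        # only has to override Conv2d/Conv3d to change every call site.
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        raise ValueError(f"unsupported dimensions: {dims}")

conv = ops_sketch.conv_nd(2, 320, 320, 3, padding=1)  # -> a torch.nn.Conv2d
```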
comfy/sd.py

```diff
@@ -433,11 +433,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
     unet_dtype = model_management.unet_dtype(model_params=parameters)
+    load_device = model_management.get_torch_device()
+    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)

     class WeightsLoader(torch.nn.Module):
         pass

     model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
+    model_config.set_manual_cast(manual_cast_dtype)
+
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@@ -470,7 +474,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
             print("left over keys:", left_over)

     if output_model:
-        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
         if inital_load_device != torch.device("cpu"):
             print("loaded straight to GPU")
             model_management.load_model_gpu(model_patcher)
@@ -481,6 +485,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 def load_unet_state_dict(sd): #load unet in diffusers format
     parameters = comfy.utils.calculate_parameters(sd)
     unet_dtype = model_management.unet_dtype(model_params=parameters)
+    load_device = model_management.get_torch_device()
+    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
+
     if "input_blocks.0.0.weight" in sd: #ldm
         model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
         if model_config is None:
@@ -501,13 +508,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
         else:
             print(diffusers_keys[k], k)

     offload_device = model_management.unet_offload_device()
+    model_config.set_manual_cast(manual_cast_dtype)
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
     if len(left_over) > 0:
         print("left over keys in unet:", left_over)
-    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+    return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)

 def load_unet(unet_path):
     sd = comfy.utils.load_torch_file(unet_path)
```
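Both loaders now follow the same sequence: pick a storage dtype for the weights, ask unet_manual_cast for a compute dtype, record it on the model config, and reuse the hoisted load_device when constructing the ModelPatcher. From the caller's side nothing changes; a hedged sketch (the checkpoint filename is a placeholder, and the tuple layout is an assumption about the surrounding API):

```python
import comfy.sd

# Same public entry point as before; the manual-cast decision is internal.
out = comfy.sd.load_checkpoint_guess_config("v1-5-pruned-emaonly.safetensors")
model_patcher = out[0]  # ModelPatcher built with the hoisted load_device
# None for natively-supported weights, or e.g. torch.float16 for fp8 weights:
print(model_patcher.model.manual_cast_dtype)
```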
comfy/supported_models_base.py

```diff
@@ -22,6 +22,8 @@ class BASE:
     sampling_settings = {}
     latent_format = latent_formats.LatentFormat

+    manual_cast_dtype = None
+
     @classmethod
     def matches(s, unet_config):
         for k in s.unet_config:
@@ -71,3 +73,5 @@ class BASE:
         replace_prefix = {"": "first_stage_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)

+    def set_manual_cast(self, manual_cast_dtype):
+        self.manual_cast_dtype = manual_cast_dtype
```
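The config side is deliberately minimal: a class-level default plus a setter, so every existing model config subclass inherits the new field without modification. A small sketch of why the class attribute works as a default (BASE_sketch and SD15_sketch are illustrative stand-ins):

```python
import torch

class BASE_sketch:
    # Class-level default: every subclass and instance reads None until a
    # loader calls set_manual_cast with a real compute dtype.
    manual_cast_dtype = None

    def set_manual_cast(self, manual_cast_dtype):
        self.manual_cast_dtype = manual_cast_dtype

class SD15_sketch(BASE_sketch):  # stands in for the real per-model configs
    pass

cfg = SD15_sketch()
assert cfg.manual_cast_dtype is None          # inherited default
cfg.set_manual_cast(torch.float16)            # per-load, instance-level decision
assert SD15_sketch.manual_cast_dtype is None  # class default stays untouched
```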