Commit ba07cb74 authored Dec 11, 2023 by comfyanonymous
Parent: ab93abd4

Use faster manual cast for fp8 in unet.
Showing 5 changed files with 48 additions and 12 deletions (+48, -12):
- comfy/model_base.py (+10, -9)
- comfy/model_management.py (+15, -1)
- comfy/ops.py (+9, -0)
- comfy/sd.py (+10, -2)
- comfy/supported_models_base.py (+4, -0)
comfy/model_base.py

```diff
@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import comfy.model_management
 import comfy.conds
+import comfy.ops
 from enum import Enum
 import contextlib
 from . import utils
@@ -41,9 +42,14 @@ class BaseModel(torch.nn.Module):
         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
+        self.manual_cast_dtype = model_config.manual_cast_dtype

         if not unet_config.get("disable_unet_model_creation", False):
-            self.diffusion_model = UNetModel(**unet_config, device=device)
+            if self.manual_cast_dtype is not None:
+                operations = comfy.ops.manual_cast
+            else:
+                operations = comfy.ops
+            self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
@@ -63,11 +69,8 @@ class BaseModel(torch.nn.Module):
         context = c_crossattn
         dtype = self.get_dtype()
-        if comfy.model_management.supports_dtype(xc.device, dtype):
-            precision_scope = lambda a: contextlib.nullcontext(a)
-        else:
-            precision_scope = torch.autocast
-            dtype = torch.float32
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype

         xc = xc.to(dtype)
         t = self.model_sampling.timestep(t).float()
@@ -79,9 +82,7 @@ class BaseModel(torch.nn.Module):
                 extra = extra.to(dtype)
             extra_conds[o] = extra
-        with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
-            model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
+        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
         return self.model_sampling.calculate_denoised(sigma, model_output, x)

     def get_dtype(self):
```
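The switch away from `torch.autocast` here is the core of the commit: instead of letting autocast re-dispatch every op inside a precision scope, each layer casts its own weights to the compute dtype once per call. Below is a minimal sketch of that per-layer pattern, not ComfyUI's actual `comfy.ops.manual_cast` code; the class and helper names are illustrative.

```python
import torch

def cast_bias_weight_sketch(layer, x):
    # Cast the stored parameters to the dtype/device of the incoming tensor.
    weight = layer.weight.to(device=x.device, dtype=x.dtype)
    bias = layer.bias.to(device=x.device, dtype=x.dtype) if layer.bias is not None else None
    return weight, bias

class ManualCastLinear(torch.nn.Linear):
    def forward(self, x):
        weight, bias = cast_bias_weight_sketch(self, x)
        return torch.nn.functional.linear(x, weight, bias)

# Weights live in a low-precision storage dtype; the cast to the compute
# dtype happens lazily inside forward (fp16 stands in for fp8 here).
layer = ManualCastLinear(8, 4).to(torch.float16)
out = layer(torch.randn(2, 8))  # float32 input
print(out.dtype)  # torch.float32
```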
comfy/model_management.py

```diff
@@ -474,6 +474,20 @@ def unet_dtype(device=None, model_params=0):
         return torch.float16
     return torch.float32

+# None means no manual cast
+def unet_manual_cast(weight_dtype, inference_device):
+    if weight_dtype == torch.float32:
+        return None
+
+    fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
+    if fp16_supported and weight_dtype == torch.float16:
+        return None
+
+    if fp16_supported:
+        return torch.float16
+    else:
+        return torch.float32
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
@@ -538,7 +552,7 @@ def get_autocast_device(dev):
 def supports_dtype(device, dtype): #TODO
     if dtype == torch.float32:
         return True
-    if torch.device("cpu") == device:
+    if is_device_cpu(device):
         return False
     if dtype == torch.float16:
         return True
```
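The new `unet_manual_cast` encodes a small decision table: fp32 weights never need a cast, fp16 weights only need one on devices without fast fp16, and everything else (notably fp8 storage dtypes) gets cast per layer to the fastest supported compute dtype. A standalone restatement of that logic, with the device probe stubbed out as a boolean (the real check is the `should_use_fp16` call above):

```python
import torch

def unet_manual_cast_sketch(weight_dtype, fp16_supported):
    # None means the model can compute directly in its weight dtype.
    if weight_dtype == torch.float32:
        return None
    if fp16_supported and weight_dtype == torch.float16:
        return None
    # e.g. fp8 weights: pick the fastest compute dtype the device supports.
    return torch.float16 if fp16_supported else torch.float32

# fp8 weights on an fp16-capable GPU -> compute in fp16
# (torch.float8_e4m3fn assumes a PyTorch build >= 2.1):
print(unet_manual_cast_sketch(torch.float8_e4m3fn, True))  # torch.float16
# fp16 weights on a device without fast fp16 -> compute in fp32:
print(unet_manual_cast_sketch(torch.float16, False))       # torch.float32
```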
comfy/ops.py

```diff
@@ -62,6 +62,15 @@ class manual_cast:
         weight, bias = cast_bias_weight(self, input)
         return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

+    @classmethod
+    def conv_nd(s, dims, *args, **kwargs):
+        if dims == 2:
+            return s.Conv2d(*args, **kwargs)
+        elif dims == 3:
+            return s.Conv3d(*args, **kwargs)
+        else:
+            raise ValueError(f"unsupported dimensions: {dims}")
+
 @contextmanager
 def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear
```
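The `conv_nd` classmethod is added because the UNet builds its convolutions through whatever operations namespace it is handed, so `manual_cast` has to mirror the factory surface of `comfy.ops`. A self-contained sketch of the pattern; `OpsSketch` is illustrative, not the upstream class:

```python
import torch

class OpsSketch:
    # An "operations" namespace: layer classes plus small factory helpers.
    Conv2d = torch.nn.Conv2d
    Conv3d = torch.nn.Conv3d

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")

# Model-building code stays agnostic about which op set it received:
conv = OpsSketch.conv_nd(2, 3, 16, kernel_size=3, padding=1)
print(conv(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 16, 32, 32])
```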
comfy/sd.py

```diff
@@ -433,11 +433,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
     unet_dtype = model_management.unet_dtype(model_params=parameters)
+    load_device = model_management.get_torch_device()
+    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)

     class WeightsLoader(torch.nn.Module):
         pass

     model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
+    model_config.set_manual_cast(manual_cast_dtype)
+
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@@ -470,7 +474,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         print("left over keys:", left_over)

     if output_model:
-        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
         if inital_load_device != torch.device("cpu"):
             print("loaded straight to GPU")
             model_management.load_model_gpu(model_patcher)
@@ -481,6 +485,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 def load_unet_state_dict(sd): #load unet in diffusers format
     parameters = comfy.utils.calculate_parameters(sd)
     unet_dtype = model_management.unet_dtype(model_params=parameters)
+    load_device = model_management.get_torch_device()
+    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
+
     if "input_blocks.0.0.weight" in sd: #ldm
         model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
     if model_config is None:
@@ -501,13 +508,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
         else:
             print(diffusers_keys[k], k)

     offload_device = model_management.unet_offload_device()
+    model_config.set_manual_cast(manual_cast_dtype)
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
     if len(left_over) > 0:
         print("left over keys in unet:", left_over)
-    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+    return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)

 def load_unet(unet_path):
     sd = comfy.utils.load_torch_file(unet_path)
```
comfy/supported_models_base.py

```diff
@@ -22,6 +22,8 @@ class BASE:
     sampling_settings = {}
     latent_format = latent_formats.LatentFormat

+    manual_cast_dtype = None
+
     @classmethod
     def matches(s, unet_config):
         for k in s.unet_config:
@@ -71,3 +73,5 @@ class BASE:
         replace_prefix = {"": "first_stage_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def set_manual_cast(self, manual_cast_dtype):
+        self.manual_cast_dtype = manual_cast_dtype
```
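Taken together, the flow is: the loaders in comfy/sd.py compute `manual_cast_dtype` from the weight dtype and load device, stash it on the model config via the new `set_manual_cast`, and `BaseModel` reads it to choose between `comfy.ops` and `comfy.ops.manual_cast`. A compressed sketch of that wiring with simplified names, not the real loader:

```python
import torch

class ModelConfigSketch:
    manual_cast_dtype = None          # mirrors supported_models_base.BASE
    def set_manual_cast(self, dtype):
        self.manual_cast_dtype = dtype

def wire_model_sketch(unet_dtype, fp16_supported):
    cfg = ModelConfigSketch()
    # Loader side (comfy/sd.py): decide once per checkpoint.
    if unet_dtype != torch.float32 and not (fp16_supported and unet_dtype == torch.float16):
        cfg.set_manual_cast(torch.float16 if fp16_supported else torch.float32)
    # Model side (comfy/model_base.py): pick the op set from the config.
    ops = "comfy.ops.manual_cast" if cfg.manual_cast_dtype is not None else "comfy.ops"
    return cfg.manual_cast_dtype, ops

print(wire_model_sketch(torch.float16, fp16_supported=True))   # (None, 'comfy.ops')
print(wire_model_sketch(torch.float16, fp16_supported=False))  # (torch.float32, 'comfy.ops.manual_cast')
```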