chenpangpang / ComfyUI · Commits

Commit 57926635, authored Dec 10, 2023 by comfyanonymous
Switch text encoder to manual cast.
Use fp16 text encoder weights for CPU inference to lower memory usage.
parent 69033081
Showing 3 changed files with 59 additions and 29 deletions (+59 -29).
comfy/model_management.py  (+3 -0)
comfy/ops.py               (+33 -0)
comfy/sd1_clip.py          (+23 -29)
comfy/model_management.py
@@ -503,6 +503,9 @@ def text_encoder_dtype(device=None):
     elif args.fp32_text_enc:
         return torch.float32
 
+    if is_device_cpu(device):
+        return torch.float16
+
     if should_use_fp16(device, prioritize_performance=False):
         return torch.float16
     else:
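The new branch makes text_encoder_dtype() return torch.float16 whenever the target device is the CPU, so the text-encoder weights are stored at half precision there. A minimal standalone sketch (illustrative, not ComfyUI code) of why fp16 storage halves the weight footprint:

import torch

# Illustrative sketch, not part of the commit: fp16 storage halves the
# parameter footprint of a text-encoder-sized linear layer.
layer_fp32 = torch.nn.Linear(768, 768)
layer_fp16 = torch.nn.Linear(768, 768).half()

def param_bytes(module):
    return sum(p.numel() * p.element_size() for p in module.parameters())

print(param_bytes(layer_fp32))  # 2362368 bytes (fp32 weight + bias)
print(param_bytes(layer_fp16))  # 1181184 bytes: exactly half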
comfy/ops.py
@@ -29,6 +29,39 @@ def conv_nd(dims, *args, **kwargs):
     else:
         raise ValueError(f"unsupported dimensions: {dims}")
 
+def cast_bias_weight(s, input):
+    bias = None
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype)
+    weight = s.weight.to(device=input.device, dtype=input.dtype)
+    return weight, bias
+
+class manual_cast:
+    class Linear(Linear):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+    class Conv2d(Conv2d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class Conv3d(Conv3d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class GroupNorm(GroupNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+    class LayerNorm(LayerNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
 @contextmanager
 def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear
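Every manual_cast class follows the same pattern: parameters stay in whatever dtype they were loaded in, and cast_bias_weight() copies them to the input's device and dtype on each forward call. A self-contained sketch of that pattern (CastLinear is an illustrative name; the real implementations are the ones in comfy/ops.py above):

import torch

# Sketch of the cast-on-forward pattern (not ComfyUI code): the stored
# weights keep their compact dtype, and compute follows the input's dtype.
class CastLinear(torch.nn.Linear):
    def forward(self, input):
        # Copy parameters to the input's device/dtype for this call only.
        weight = self.weight.to(device=input.device, dtype=input.dtype)
        bias = None
        if self.bias is not None:
            bias = self.bias.to(device=input.device, dtype=input.dtype)
        return torch.nn.functional.linear(input, weight, bias)

layer = CastLinear(4, 4).half()  # weights stored in fp16
x = torch.randn(2, 4)            # fp32 input on CPU
print(layer(x).dtype)            # torch.float32: compute follows the input dtype

This trades a small per-call cast for the freedom to store weights in a compact dtype, which is what lets the CPU path above keep the text encoder in fp16 while still computing in fp32.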
comfy/sd1_clip.py
@@ -78,7 +78,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         with open(textmodel_json_config) as f:
             config = json.load(f)
 
-        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
         self.num_layers = self.transformer.num_layers
 
         self.max_length = max_length
@@ -160,12 +160,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if self.transformer.dtype != torch.float32:
-            precision_scope = torch.autocast
-        else:
-            precision_scope = lambda a, dtype: contextlib.nullcontext(a)
-
-        with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
-            attention_mask = None
-            if self.enable_attention_masks:
-                attention_mask = torch.zeros_like(tokens)
+        attention_mask = None
+        if self.enable_attention_masks:
+            attention_mask = torch.zeros_like(tokens)
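With the transformer now built from comfy.ops.manual_cast layers, each op handles its own dtype conversion, so the explicit torch.autocast precision scope becomes redundant and is deleted. The one-line constructor change works because model_class accepts an operations namespace and builds its layers from it; a hedged sketch of that injection pattern (TinyEncoder and default_ops are illustrative stand-ins, not ComfyUI code):

import torch

# Sketch of the operations-injection pattern behind the line-78 change:
# the model builds its layers from whichever ops namespace the caller
# passes, so swapping comfy.ops for comfy.ops.manual_cast changes cast
# behaviour without editing the model itself.
class TinyEncoder(torch.nn.Module):
    def __init__(self, dim, operations):
        super().__init__()
        self.proj = operations.Linear(dim, dim)  # layer class chosen by the caller

    def forward(self, x):
        return self.proj(x)

class default_ops:
    Linear = torch.nn.Linear

model = TinyEncoder(8, default_ops)
print(model(torch.randn(1, 8)).shape)  # torch.Size([1, 8])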