Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
4a8a839b
"...resnet50_tensorflow.git" did not exist on "e653807f13dbe6349e4202fb162016a0a0b339ef"
Commit
4a8a839b
authored
Nov 11, 2023
by
comfyanonymous
Browse files
Add option to use in place weight updating in ModelPatcher.
parent
412d3ff5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
5 deletions
+24
-5
comfy/model_patcher.py
comfy/model_patcher.py
+16
-5
comfy/utils.py
comfy/utils.py
+8
-0
No files found.
comfy/model_patcher.py
View file @
4a8a839b
...
@@ -6,7 +6,7 @@ import comfy.utils
...
@@ -6,7 +6,7 @@ import comfy.utils
import
comfy.model_management
import
comfy.model_management
class
ModelPatcher
:
class
ModelPatcher
:
def
__init__
(
self
,
model
,
load_device
,
offload_device
,
size
=
0
,
current_device
=
None
):
def
__init__
(
self
,
model
,
load_device
,
offload_device
,
size
=
0
,
current_device
=
None
,
weight_inplace_update
=
False
):
self
.
size
=
size
self
.
size
=
size
self
.
model
=
model
self
.
model
=
model
self
.
patches
=
{}
self
.
patches
=
{}
...
@@ -22,6 +22,8 @@ class ModelPatcher:
...
@@ -22,6 +22,8 @@ class ModelPatcher:
else
:
else
:
self
.
current_device
=
current_device
self
.
current_device
=
current_device
self
.
weight_inplace_update
=
weight_inplace_update
def
model_size
(
self
):
def
model_size
(
self
):
if
self
.
size
>
0
:
if
self
.
size
>
0
:
return
self
.
size
return
self
.
size
...
@@ -171,15 +173,20 @@ class ModelPatcher:
...
@@ -171,15 +173,20 @@ class ModelPatcher:
weight
=
model_sd
[
key
]
weight
=
model_sd
[
key
]
inplace_update
=
self
.
weight_inplace_update
if
key
not
in
self
.
backup
:
if
key
not
in
self
.
backup
:
self
.
backup
[
key
]
=
weight
.
to
(
self
.
offload_devic
e
)
self
.
backup
[
key
]
=
weight
.
to
(
device
=
device_to
,
copy
=
inplace_updat
e
)
if
device_to
is
not
None
:
if
device_to
is
not
None
:
temp_weight
=
comfy
.
model_management
.
cast_to_device
(
weight
,
device_to
,
torch
.
float32
,
copy
=
True
)
temp_weight
=
comfy
.
model_management
.
cast_to_device
(
weight
,
device_to
,
torch
.
float32
,
copy
=
True
)
else
:
else
:
temp_weight
=
weight
.
to
(
torch
.
float32
,
copy
=
True
)
temp_weight
=
weight
.
to
(
torch
.
float32
,
copy
=
True
)
out_weight
=
self
.
calculate_weight
(
self
.
patches
[
key
],
temp_weight
,
key
).
to
(
weight
.
dtype
)
out_weight
=
self
.
calculate_weight
(
self
.
patches
[
key
],
temp_weight
,
key
).
to
(
weight
.
dtype
)
comfy
.
utils
.
set_attr
(
self
.
model
,
key
,
out_weight
)
if
inplace_update
:
comfy
.
utils
.
copy_to_param
(
self
.
model
,
key
,
out_weight
)
else
:
comfy
.
utils
.
set_attr
(
self
.
model
,
key
,
out_weight
)
del
temp_weight
del
temp_weight
if
device_to
is
not
None
:
if
device_to
is
not
None
:
...
@@ -295,8 +302,12 @@ class ModelPatcher:
...
@@ -295,8 +302,12 @@ class ModelPatcher:
def
unpatch_model
(
self
,
device_to
=
None
):
def
unpatch_model
(
self
,
device_to
=
None
):
keys
=
list
(
self
.
backup
.
keys
())
keys
=
list
(
self
.
backup
.
keys
())
for
k
in
keys
:
if
self
.
weight_inplace_update
:
comfy
.
utils
.
set_attr
(
self
.
model
,
k
,
self
.
backup
[
k
])
for
k
in
keys
:
comfy
.
utils
.
copy_to_param
(
self
.
model
,
k
,
self
.
backup
[
k
])
else
:
for
k
in
keys
:
comfy
.
utils
.
set_attr
(
self
.
model
,
k
,
self
.
backup
[
k
])
self
.
backup
=
{}
self
.
backup
=
{}
...
...
comfy/utils.py
View file @
4a8a839b
...
@@ -261,6 +261,14 @@ def set_attr(obj, attr, value):
...
@@ -261,6 +261,14 @@ def set_attr(obj, attr, value):
setattr
(
obj
,
attrs
[
-
1
],
torch
.
nn
.
Parameter
(
value
))
setattr
(
obj
,
attrs
[
-
1
],
torch
.
nn
.
Parameter
(
value
))
del
prev
del
prev
def
copy_to_param
(
obj
,
attr
,
value
):
# inplace update tensor instead of replacing it
attrs
=
attr
.
split
(
"."
)
for
name
in
attrs
[:
-
1
]:
obj
=
getattr
(
obj
,
name
)
prev
=
getattr
(
obj
,
attrs
[
-
1
])
prev
.
data
.
copy_
(
value
)
def
get_attr
(
obj
,
attr
):
def
get_attr
(
obj
,
attr
):
attrs
=
attr
.
split
(
"."
)
attrs
=
attr
.
split
(
"."
)
for
name
in
attrs
:
for
name
in
attrs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment