chenpangpang / ComfyUI · Commits

Commit dccca1da
Authored Feb 18, 2024 by comfyanonymous
Parent 8b60d33b

    Fix gligen lowvram mode.

Showing 1 changed file with 27 additions and 25 deletions:

comfy/gligen.py (+27, -25)
@@ -2,7 +2,8 @@ import torch
 from torch import nn
 from .ldm.modules.attention import CrossAttention
 from inspect import isfunction
+import comfy.ops
+ops = comfy.ops.manual_cast

 def exists(val):
     return val is not None
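The two added lines above switch this file from building layers directly out of torch.nn to building them out of comfy.ops.manual_cast, which is what the rest of the diff relies on: a manual-cast layer keeps its stored weights wherever model management put them and casts them to the incoming activation's dtype and device at forward time, which is what lowvram offloading needs. A minimal sketch of the idea, using a simplified Linear (an illustration only, not ComfyUI's actual comfy.ops implementation):

    # Sketch only: a Linear that casts its parameters to match the input at call time.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ManualCastLinear(nn.Linear):
        def forward(self, x):
            # Cast copies of the stored weight/bias to the input's device and dtype,
            # so the stored parameters can stay offloaded or in a different precision.
            weight = self.weight.to(device=x.device, dtype=x.dtype)
            bias = None if self.bias is None else self.bias.to(device=x.device, dtype=x.dtype)
            return F.linear(x, weight, bias)

    layer = ManualCastLinear(8, 4)                         # parameters created in fp32
    out = layer(torch.randn(2, 8, dtype=torch.bfloat16))   # bf16 input still works
    print(out.dtype)                                       # torch.bfloat16

With layers like these, callers no longer have to cast inputs themselves, which is why the forward() changes later in the diff mostly delete casts rather than add them.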
@@ -22,7 +23,7 @@ def default(val, d):
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
+        self.proj = ops.Linear(dim_in, dim_out * 2)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -35,14 +36,14 @@ class FeedForward(nn.Module):
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            nn.Linear(dim, inner_dim),
+            ops.Linear(dim, inner_dim),
             nn.GELU()
         ) if not glu else GEGLU(dim, inner_dim)

         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out)
+            ops.Linear(inner_dim, dim_out)
         )

     def forward(self, x):
@@ -57,11 +58,12 @@ class GatedCrossAttentionDense(nn.Module):
             query_dim=query_dim,
             context_dim=context_dim,
             heads=n_heads,
-            dim_head=d_head)
+            dim_head=d_head,
+            operations=ops)
         self.ff = FeedForward(query_dim, glu=True)

-        self.norm1 = nn.LayerNorm(query_dim)
-        self.norm2 = nn.LayerNorm(query_dim)
+        self.norm1 = ops.LayerNorm(query_dim)
+        self.norm2 = ops.LayerNorm(query_dim)

         self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
         self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
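Besides swapping nn for ops, the hunks for the gated attention blocks also start passing operations=ops into CrossAttention, so the attention module builds its own projections from the same manual-cast namespace. A rough sketch of that injection pattern, using a made-up TinyAttention class (the real CrossAttention signature has more parameters):

    # Sketch of the operations= injection pattern (TinyAttention is hypothetical).
    import types
    import torch.nn as nn

    class TinyAttention(nn.Module):
        def __init__(self, query_dim, context_dim, operations=nn):
            super().__init__()
            # Sub-layers come from the injected namespace, not from nn directly.
            self.to_q = operations.Linear(query_dim, query_dim)
            self.to_k = operations.Linear(context_dim, query_dim)

    plain = TinyAttention(64, 32)                        # defaults to torch.nn layers
    ops_like = types.SimpleNamespace(Linear=nn.Linear)   # stand-in for comfy.ops.manual_cast
    casted = TinyAttention(64, 32, operations=ops_like)

The one keyword argument is enough to change which layer implementations the attention block instantiates, without touching its own code.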
@@ -87,17 +89,18 @@ class GatedSelfAttentionDense(nn.Module):
         # we need a linear projection since we need cat visual feature and obj
         # feature
-        self.linear = nn.Linear(context_dim, query_dim)
+        self.linear = ops.Linear(context_dim, query_dim)

         self.attn = CrossAttention(
             query_dim=query_dim,
             context_dim=query_dim,
             heads=n_heads,
-            dim_head=d_head)
+            dim_head=d_head,
+            operations=ops)
         self.ff = FeedForward(query_dim, glu=True)

-        self.norm1 = nn.LayerNorm(query_dim)
-        self.norm2 = nn.LayerNorm(query_dim)
+        self.norm1 = ops.LayerNorm(query_dim)
+        self.norm2 = ops.LayerNorm(query_dim)

         self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
         self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@@ -126,14 +129,14 @@ class GatedSelfAttentionDense2(nn.Module):
         # we need a linear projection since we need cat visual feature and obj
         # feature
-        self.linear = nn.Linear(context_dim, query_dim)
+        self.linear = ops.Linear(context_dim, query_dim)

         self.attn = CrossAttention(
-            query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
+            query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
         self.ff = FeedForward(query_dim, glu=True)

-        self.norm1 = nn.LayerNorm(query_dim)
-        self.norm2 = nn.LayerNorm(query_dim)
+        self.norm1 = ops.LayerNorm(query_dim)
+        self.norm2 = ops.LayerNorm(query_dim)

         self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
         self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@@ -201,11 +204,11 @@ class PositionNet(nn.Module):
         self.position_dim = fourier_freqs * 2 * 4  # 2 is sin&cos, 4 is xyxy

         self.linears = nn.Sequential(
-            nn.Linear(self.in_dim + self.position_dim, 512),
+            ops.Linear(self.in_dim + self.position_dim, 512),
             nn.SiLU(),
-            nn.Linear(512, 512),
+            ops.Linear(512, 512),
             nn.SiLU(),
-            nn.Linear(512, out_dim),
+            ops.Linear(512, out_dim),
         )

         self.null_positive_feature = torch.nn.Parameter(
@@ -215,16 +218,15 @@ class PositionNet(nn.Module):
     def forward(self, boxes, masks, positive_embeddings):
         B, N, _ = boxes.shape
-        dtype = self.linears[0].weight.dtype
-        masks = masks.unsqueeze(-1).to(dtype)
-        positive_embeddings = positive_embeddings.to(dtype)
+        masks = masks.unsqueeze(-1)
+        positive_embeddings = positive_embeddings

         # embedding position (it may includes padding as placeholder)
-        xyxy_embedding = self.fourier_embedder(boxes.to(dtype))  # B*N*4 --> B*N*C
+        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 --> B*N*C

         # learnable null embedding
-        positive_null = self.null_positive_feature.view(1, 1, -1)
-        xyxy_null = self.null_position_feature.view(1, 1, -1)
+        positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
+        xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)

         # replace padding with learnable null embedding
         positive_embeddings = positive_embeddings * \
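This is the hunk that actually addresses the lowvram failure: the old code probed self.linears[0].weight.dtype and cast every input to it, but when the weights are offloaded or stored in a different precision (and cast on the fly by the manual-cast layers), that probe no longer describes where the computation runs. The new code leaves the inputs alone and instead moves the two learnable null embeddings, plain nn.Parameters that the ops layers cannot manage, onto the incoming boxes' device and dtype. A small illustration of the pattern, with a hypothetical null_feature parameter; the same idea is applied to objs in the hunk below:

    # Follow the input tensor, not the stored parameter (null_feature is hypothetical).
    import torch
    import torch.nn as nn

    null_feature = nn.Parameter(torch.zeros(768))  # lives wherever the module was created

    def expand_null(boxes):
        # Cast to the device/dtype of the data that actually arrived at forward().
        return null_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)

    print(expand_null(torch.zeros(2, 4, 4, dtype=torch.bfloat16)).shape)  # torch.Size([1, 1, 768])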
@@ -251,7 +253,7 @@ class Gligen(nn.Module):
         def func(x, extra_options):
             key = extra_options["transformer_index"]
             module = self.module_list[key]
-            return module(x, objs)
+            return module(x, objs.to(device=x.device, dtype=x.dtype))
         return func

     def set_position(self, latent_image_shape, position_params, device):
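The last hunk applies the same idea to the attention patch: objs is produced once by PositionNet, but under lowvram offloading each patched transformer block may receive activations x on a different device or in a different dtype, so the conditioning is now cast to match x on every call. The cast is effectively free when nothing needs to change, since torch returns the same tensor object:

    # .to() is a no-op (returns the same object) when device and dtype already match.
    import torch

    objs = torch.randn(1, 30, 768)
    x = torch.randn(2, 4096, 768)
    print(objs.to(device=x.device, dtype=x.dtype) is objs)  # True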