chenpangpang / ComfyUI · Commits
"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "2f9d6a97ec7e3cb25beb13a320da8ec4573355d3"
Commit ae43f09e · authored Jun 15, 2023 by comfyanonymous

All the unet weights should now be initialized with the right dtype.

Parent: cf3974c8
Showing 3 changed files with 29 additions and 23 deletions.
comfy/ldm/modules/attention.py  (+6 -6)
comfy/ldm/modules/diffusionmodules/openaimodel.py  (+21 -15)
comfy/ldm/modules/diffusionmodules/util.py  (+2 -2)
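Note (editorial addition, not part of the commit): passing dtype to a module's constructor creates its parameters directly in the target precision, instead of first allocating float32 weights and casting them afterwards, which lowers peak memory while an fp16 UNet is being built. A minimal PyTorch sketch of the difference:

import torch
import torch.nn as nn

# Weights are born as float16; no intermediate float32 copy is ever allocated.
direct_fp16 = nn.Linear(320, 1280, dtype=torch.float16)

# Weights are allocated as float32 first, then cast; peak memory is higher.
cast_fp16 = nn.Linear(320, 1280).half()

print(direct_fp16.weight.dtype)  # torch.float16
print(cast_fp16.weight.dtype)    # torch.float16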
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -51,9 +51,9 @@ def init_(tensor):

 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
+    def __init__(self, dim_in, dim_out, dtype=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -68,7 +68,7 @@ class FeedForward(nn.Module):
         project_in = nn.Sequential(
             comfy.ops.Linear(dim, inner_dim, dtype=dtype),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)

         self.net = nn.Sequential(
             project_in,
@@ -89,8 +89,8 @@ def zero_module(module):
     return module


-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+def Normalize(in_channels, dtype=None):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)


 class SpatialSelfAttention(nn.Module):
@@ -594,7 +594,7 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels)
+        self.norm = Normalize(in_channels, dtype=dtype)
         if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
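The pattern in the hunks above is that each module accepts an optional dtype argument and forwards it to every layer it constructs, so the caller can decide the precision once. A small self-contained sketch of the same idea, using plain nn.Linear and an illustrative class name rather than ComfyUI's own code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyGEGLU(nn.Module):
    """Illustrative GEGLU-style block: dtype is accepted and forwarded to the layer it builds."""
    def __init__(self, dim_in, dim_out, dtype=None):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, dtype=dtype)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)

block = TinyGEGLU(64, 128, dtype=torch.float16)
assert block.proj.weight.dtype == torch.float16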
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -111,14 +111,14 @@ class Upsample(nn.Module):
     upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)

     def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels
@@ -160,7 +160,7 @@ class Downsample(nn.Module):
     downsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -169,7 +169,7 @@ class Downsample(nn.Module):
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
             )
         else:
             assert self.channels == self.out_channels
@@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm

         self.in_layers = nn.Sequential(
-            normalization(channels),
+            normalization(channels, dtype=dtype),
             nn.SiLU(),
             conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
         )
@@ -228,11 +228,11 @@ class ResBlock(TimestepBlock):
         self.updown = up or down

         if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Upsample(channels, False, dims, dtype=dtype)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Downsample(channels, False, dims, dtype=dtype)
         else:
             self.h_upd = self.x_upd = nn.Identity()
@@ -240,11 +240,11 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
             ),
         )
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels),
+            normalization(self.out_channels, dtype=dtype),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
@@ -604,6 +604,7 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        dtype=self.dtype
                     )
                 ]
                 ch = mult * model_channels
@@ -651,10 +652,11 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
+                            dtype=self.dtype
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
                         )
                     )
                 )
@@ -679,6 +681,7 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype
             ),
             AttentionBlock(
                 ch,
@@ -698,6 +701,7 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype
             ),
         )
         self._feature_size += ch
@@ -715,6 +719,7 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        dtype=self.dtype
                     )
                 ]
                 ch = model_channels * mult
@@ -758,18 +763,19 @@ class UNetModel(nn.Module):
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        up=True,
+                        dtype=self.dtype
                    )
                    if resblock_updown
-                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
                )
                ds //= 2
            self.output_blocks.append(TimestepEmbedSequential(*layers))
            self._feature_size += ch

        self.out = nn.Sequential(
-            normalization(ch),
+            normalization(ch, dtype=self.dtype),
            nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
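In UNetModel the threading starts at the top: the model records its own dtype and hands dtype=self.dtype to every ResBlock, Upsample, Downsample, normalization, and conv_nd it constructs. A self-contained sketch of that propagation, with hypothetical TinyBlock / TinyUNet classes standing in for the real ones:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyBlock(nn.Module):
    # Stand-in for a ResBlock: every layer it builds inherits the requested dtype.
    def __init__(self, ch, dtype=None):
        super().__init__()
        self.norm = nn.GroupNorm(32, ch, dtype=dtype)
        self.conv = nn.Conv2d(ch, ch, 3, padding=1, dtype=dtype)

    def forward(self, x):
        return self.conv(F.silu(self.norm(x)))

class TinyUNet(nn.Module):
    # Stand-in for UNetModel: the dtype is decided once and handed to each child.
    def __init__(self, ch=64, dtype=torch.float16):
        super().__init__()
        self.dtype = dtype
        self.blocks = nn.ModuleList([TinyBlock(ch, dtype=self.dtype) for _ in range(3)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = TinyUNet()
assert all(p.dtype == torch.float16 for p in model.parameters())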
--- a/comfy/ldm/modules/diffusionmodules/util.py
+++ b/comfy/ldm/modules/diffusionmodules/util.py
@@ -206,13 +206,13 @@ def mean_flat(tensor):
     return tensor.mean(dim=list(range(1, len(tensor.shape))))


-def normalization(channels):
+def normalization(channels, dtype=None):
     """
     Make a standard normalization layer.
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
-    return GroupNorm32(32, channels)
+    return GroupNorm32(32, channels, dtype=dtype)


 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
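With normalization() also accepting a dtype, every parameter-holding layer along the UNet construction path can be created in the target precision, which is what the commit message claims. A quick way to check that on any constructed model (an illustrative helper, not part of this commit):

import torch

def find_mismatched_dtypes(model, expected=torch.float16):
    # Return the names and dtypes of parameters that were not created in the expected precision.
    return {name: p.dtype for name, p in model.named_parameters() if p.dtype != expected}

# Example: find_mismatched_dtypes(unet, torch.float16) should come back empty
# after this change, assuming `unet` was built with dtype=torch.float16.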