Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
ae43f09e
Commit
ae43f09e
authored
Jun 15, 2023
by
comfyanonymous
Browse files
All the unet weights should now be initialized with the right dtype.
parent
cf3974c8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
23 deletions
+29
-23
comfy/ldm/modules/attention.py
comfy/ldm/modules/attention.py
+6
-6
comfy/ldm/modules/diffusionmodules/openaimodel.py
comfy/ldm/modules/diffusionmodules/openaimodel.py
+21
-15
comfy/ldm/modules/diffusionmodules/util.py
comfy/ldm/modules/diffusionmodules/util.py
+2
-2
No files found.
comfy/ldm/modules/attention.py
View file @
ae43f09e
...
...
@@ -51,9 +51,9 @@ def init_(tensor):
# feedforward
class
GEGLU
(
nn
.
Module
):
def
__init__
(
self
,
dim_in
,
dim_out
):
def
__init__
(
self
,
dim_in
,
dim_out
,
dtype
=
None
):
super
().
__init__
()
self
.
proj
=
comfy
.
ops
.
Linear
(
dim_in
,
dim_out
*
2
)
self
.
proj
=
comfy
.
ops
.
Linear
(
dim_in
,
dim_out
*
2
,
dtype
=
dtype
)
def
forward
(
self
,
x
):
x
,
gate
=
self
.
proj
(
x
).
chunk
(
2
,
dim
=-
1
)
...
...
@@ -68,7 +68,7 @@ class FeedForward(nn.Module):
project_in
=
nn
.
Sequential
(
comfy
.
ops
.
Linear
(
dim
,
inner_dim
,
dtype
=
dtype
),
nn
.
GELU
()
)
if
not
glu
else
GEGLU
(
dim
,
inner_dim
)
)
if
not
glu
else
GEGLU
(
dim
,
inner_dim
,
dtype
=
dtype
)
self
.
net
=
nn
.
Sequential
(
project_in
,
...
...
@@ -89,8 +89,8 @@ def zero_module(module):
return
module
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
def
Normalize
(
in_channels
,
dtype
=
None
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
,
dtype
=
dtype
)
class
SpatialSelfAttention
(
nn
.
Module
):
...
...
@@ -594,7 +594,7 @@ class SpatialTransformer(nn.Module):
context_dim
=
[
context_dim
]
self
.
in_channels
=
in_channels
inner_dim
=
n_heads
*
d_head
self
.
norm
=
Normalize
(
in_channels
)
self
.
norm
=
Normalize
(
in_channels
,
dtype
=
dtype
)
if
not
use_linear
:
self
.
proj_in
=
nn
.
Conv2d
(
in_channels
,
inner_dim
,
...
...
comfy/ldm/modules/diffusionmodules/openaimodel.py
View file @
ae43f09e
...
...
@@ -111,14 +111,14 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
):
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
,
dtype
=
None
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
dims
=
dims
if
use_conv
:
self
.
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
padding
)
self
.
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
padding
,
dtype
=
dtype
)
def
forward
(
self
,
x
,
output_shape
=
None
):
assert
x
.
shape
[
1
]
==
self
.
channels
...
...
@@ -160,7 +160,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
):
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
,
dtype
=
None
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
...
...
@@ -169,7 +169,7 @@ class Downsample(nn.Module):
stride
=
2
if
dims
!=
3
else
(
1
,
2
,
2
)
if
use_conv
:
self
.
op
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
,
dtype
=
dtype
)
else
:
assert
self
.
channels
==
self
.
out_channels
...
...
@@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
self
.
use_scale_shift_norm
=
use_scale_shift_norm
self
.
in_layers
=
nn
.
Sequential
(
normalization
(
channels
),
normalization
(
channels
,
dtype
=
dtype
),
nn
.
SiLU
(),
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
,
dtype
=
dtype
),
)
...
...
@@ -228,11 +228,11 @@ class ResBlock(TimestepBlock):
self
.
updown
=
up
or
down
if
up
:
self
.
h_upd
=
Upsample
(
channels
,
False
,
dims
)
self
.
x_upd
=
Upsample
(
channels
,
False
,
dims
)
self
.
h_upd
=
Upsample
(
channels
,
False
,
dims
,
dtype
=
dtype
)
self
.
x_upd
=
Upsample
(
channels
,
False
,
dims
,
dtype
=
dtype
)
elif
down
:
self
.
h_upd
=
Downsample
(
channels
,
False
,
dims
)
self
.
x_upd
=
Downsample
(
channels
,
False
,
dims
)
self
.
h_upd
=
Downsample
(
channels
,
False
,
dims
,
dtype
=
dtype
)
self
.
x_upd
=
Downsample
(
channels
,
False
,
dims
,
dtype
=
dtype
)
else
:
self
.
h_upd
=
self
.
x_upd
=
nn
.
Identity
()
...
...
@@ -240,11 +240,11 @@ class ResBlock(TimestepBlock):
nn
.
SiLU
(),
linear
(
emb_channels
,
2
*
self
.
out_channels
if
use_scale_shift_norm
else
self
.
out_channels
,
2
*
self
.
out_channels
if
use_scale_shift_norm
else
self
.
out_channels
,
dtype
=
dtype
),
)
self
.
out_layers
=
nn
.
Sequential
(
normalization
(
self
.
out_channels
),
normalization
(
self
.
out_channels
,
dtype
=
dtype
),
nn
.
SiLU
(),
nn
.
Dropout
(
p
=
dropout
),
zero_module
(
...
...
@@ -604,6 +604,7 @@ class UNetModel(nn.Module):
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
dtype
=
self
.
dtype
)
]
ch
=
mult
*
model_channels
...
...
@@ -651,10 +652,11 @@ class UNetModel(nn.Module):
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
down
=
True
,
dtype
=
self
.
dtype
)
if
resblock_updown
else
Downsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
,
dtype
=
self
.
dtype
)
)
)
...
...
@@ -679,6 +681,7 @@ class UNetModel(nn.Module):
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
dtype
=
self
.
dtype
),
AttentionBlock
(
ch
,
...
...
@@ -698,6 +701,7 @@ class UNetModel(nn.Module):
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
dtype
=
self
.
dtype
),
)
self
.
_feature_size
+=
ch
...
...
@@ -715,6 +719,7 @@ class UNetModel(nn.Module):
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
dtype
=
self
.
dtype
)
]
ch
=
model_channels
*
mult
...
...
@@ -758,18 +763,19 @@ class UNetModel(nn.Module):
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
up
=
True
,
dtype
=
self
.
dtype
)
if
resblock_updown
else
Upsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
else
Upsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
,
dtype
=
self
.
dtype
)
)
ds
//=
2
self
.
output_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
_feature_size
+=
ch
self
.
out
=
nn
.
Sequential
(
normalization
(
ch
),
normalization
(
ch
,
dtype
=
self
.
dtype
),
nn
.
SiLU
(),
zero_module
(
conv_nd
(
dims
,
model_channels
,
out_channels
,
3
,
padding
=
1
)),
zero_module
(
conv_nd
(
dims
,
model_channels
,
out_channels
,
3
,
padding
=
1
,
dtype
=
self
.
dtype
)),
)
if
self
.
predict_codebook_ids
:
self
.
id_predictor
=
nn
.
Sequential
(
...
...
comfy/ldm/modules/diffusionmodules/util.py
View file @
ae43f09e
...
...
@@ -206,13 +206,13 @@ def mean_flat(tensor):
return
tensor
.
mean
(
dim
=
list
(
range
(
1
,
len
(
tensor
.
shape
))))
def
normalization
(
channels
):
def
normalization
(
channels
,
dtype
=
None
):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return
GroupNorm32
(
32
,
channels
)
return
GroupNorm32
(
32
,
channels
,
dtype
=
dtype
)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment