Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b53924c7
Unverified
Commit
b53924c7
authored
Jun 08, 2022
by
Suraj Patil
Committed by
GitHub
Jun 08, 2022
Browse files
Merge pull request #6 from huggingface/add-ldm
add unet ldm in init
parents
ee71a3b6
4d53a521
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
4 additions
and
2 deletions
+4
-2
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-0
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+1
-0
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+2
-2
No files found.
src/diffusers/__init__.py
View file @
b53924c7
...
@@ -7,5 +7,6 @@ __version__ = "0.0.1"
...
@@ -7,5 +7,6 @@ __version__ = "0.0.1"
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
from
.models.unet
import
UNetModel
from
.models.unet
import
UNetModel
from
.models.unet_glide
import
UNetGLIDEModel
from
.models.unet_glide
import
UNetGLIDEModel
from
.models.unet_ldm
import
UNetLDMModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.schedulers.gaussian_ddpm
import
GaussianDDPMScheduler
from
.schedulers.gaussian_ddpm
import
GaussianDDPMScheduler
src/diffusers/models/__init__.py
View file @
b53924c7
...
@@ -18,3 +18,4 @@
...
@@ -18,3 +18,4 @@
from
.unet
import
UNetModel
from
.unet
import
UNetModel
from
.unet_glide
import
UNetGLIDEModel
from
.unet_glide
import
UNetGLIDEModel
from
.unet_ldm
import
UNetLDMModel
src/diffusers/models/unet_ldm.py
View file @
b53924c7
...
@@ -830,7 +830,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -830,7 +830,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self
.
conv_resample
=
conv_resample
self
.
conv_resample
=
conv_resample
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
use_checkpoint
=
use_checkpoint
self
.
use_checkpoint
=
use_checkpoint
self
.
dtype
=
torch
.
float16
if
use_fp16
else
torch
.
float32
self
.
dtype
_
=
torch
.
float16
if
use_fp16
else
torch
.
float32
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
num_head_channels
=
num_head_channels
self
.
num_head_channels
=
num_head_channels
self
.
num_heads_upsample
=
num_heads_upsample
self
.
num_heads_upsample
=
num_heads_upsample
...
@@ -1060,7 +1060,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -1060,7 +1060,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
assert
y
.
shape
==
(
x
.
shape
[
0
],)
assert
y
.
shape
==
(
x
.
shape
[
0
],)
emb
=
emb
+
self
.
label_emb
(
y
)
emb
=
emb
+
self
.
label_emb
(
y
)
h
=
x
.
type
(
self
.
dtype
)
h
=
x
.
type
(
self
.
dtype
_
)
for
module
in
self
.
input_blocks
:
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
,
context
)
h
=
module
(
h
,
emb
,
context
)
hs
.
append
(
h
)
hs
.
append
(
h
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment