Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
88b015dc
Unverified
Commit
88b015dc
authored
Dec 18, 2024
by
Xinyuan Zhao
Committed by
GitHub
Dec 17, 2024
Browse files
Make `time_embed_dim` of `UNet2DModel` changeable (#10262)
parent
63cdf9c0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
1 deletion
+2
-1
src/diffusers/models/unets/unet_2d.py
src/diffusers/models/unets/unet_2d.py
+2
-1
No files found.
src/diffusers/models/unets/unet_2d.py
View file @
88b015dc
...
...
@@ -97,6 +97,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
out_channels
:
int
=
3
,
center_input_sample
:
bool
=
False
,
time_embedding_type
:
str
=
"positional"
,
time_embedding_dim
:
Optional
[
int
]
=
None
,
freq_shift
:
int
=
0
,
flip_sin_to_cos
:
bool
=
True
,
down_block_types
:
Tuple
[
str
,
...]
=
(
"DownBlock2D"
,
"AttnDownBlock2D"
,
"AttnDownBlock2D"
,
"AttnDownBlock2D"
),
...
...
@@ -122,7 +123,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
super
().
__init__
()
self
.
sample_size
=
sample_size
time_embed_dim
=
block_out_channels
[
0
]
*
4
time_embed_dim
=
time_embedding_dim
or
block_out_channels
[
0
]
*
4
# Check inputs
if
len
(
down_block_types
)
!=
len
(
up_block_types
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment