Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
554b374d
Commit
554b374d
authored
Nov 15, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
d5ab55e4
a0520193
Changes
76
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
756 additions
and
122 deletions
+756
-122
src/diffusers/experimental/__init__.py
src/diffusers/experimental/__init__.py
+1
-0
src/diffusers/experimental/rl/__init__.py
src/diffusers/experimental/rl/__init__.py
+1
-0
src/diffusers/experimental/rl/value_guided_sampling.py
src/diffusers/experimental/rl/value_guided_sampling.py
+129
-0
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+10
-3
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+136
-2
src/diffusers/models/unet_1d.py
src/diffusers/models/unet_1d.py
+94
-21
src/diffusers/models/unet_1d_blocks.py
src/diffusers/models/unet_1d_blocks.py
+315
-31
src/diffusers/models/unet_2d.py
src/diffusers/models/unet_2d.py
+1
-1
src/diffusers/models/unet_2d_condition.py
src/diffusers/models/unet_2d_condition.py
+1
-1
src/diffusers/pipeline_flax_utils.py
src/diffusers/pipeline_flax_utils.py
+5
-5
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+9
-9
src/diffusers/pipelines/README.md
src/diffusers/pipelines/README.md
+1
-1
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+1
-1
src/diffusers/pipelines/stable_diffusion/README.md
src/diffusers/pipelines/stable_diffusion/README.md
+3
-3
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+4
-11
src/diffusers/schedulers/scheduling_ddim_flax.py
src/diffusers/schedulers/scheduling_ddim_flax.py
+10
-3
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+10
-13
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+11
-4
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+4
-10
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
...ffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+10
-3
No files found.
src/diffusers/experimental/__init__.py
0 → 100644
View file @
554b374d
from
.rl
import
ValueGuidedRLPipeline
src/diffusers/experimental/rl/__init__.py
0 → 100644
View file @
554b374d
from
.value_guided_sampling
import
ValueGuidedRLPipeline
src/diffusers/experimental/rl/value_guided_sampling.py
0 → 100644
View file @
554b374d
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
torch
import
tqdm
from
...models.unet_1d
import
UNet1DModel
from
...pipeline_utils
import
DiffusionPipeline
from
...utils.dummy_pt_objects
import
DDPMScheduler
class ValueGuidedRLPipeline(DiffusionPipeline):
    """Value-guided diffusion sampling for reinforcement-learning planning.

    Samples candidate trajectories with a diffusion ``unet`` while nudging each
    denoising step along the gradient of a learned ``value_function``, so that
    high-value trajectories become more likely. The best trajectory's first
    action is returned.

    Args:
        value_function: UNet scoring a (state, action) trajectory tensor.
        unet: diffusion model over trajectories.
        scheduler: DDPM-style scheduler driving the reverse process.
        env: environment exposing ``get_dataset()``, ``observation_space`` and
            ``action_space`` (d4rl-style interface — TODO confirm).
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env
        self.data = env.get_dataset()

        # Per-key normalization statistics of the offline dataset. Some keys
        # hold non-numeric data that cannot be averaged; those are skipped on
        # purpose (best-effort). Fixed: the original bare `except:` also
        # swallowed KeyboardInterrupt/SystemExit — narrow it to Exception.
        self.means = {}
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except Exception:
                pass
        self.stds = {}
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except Exception:
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        """Standardize `x_in` using the dataset statistics stored under `key`."""
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        """Invert :meth:`normalize` for the statistics stored under `key`."""
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        """Recursively move `x_in` (dict / tensor / array-like) onto the UNet's device."""
        # Fixed: use isinstance() instead of `type(x_in) is dict` so dict
        # subclasses (e.g. OrderedDict) are handled too.
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        """Overwrite the state part of conditioned timesteps in `x_in`.

        `cond` maps a timestep index to the observation that the trajectory is
        pinned to; only the state slice (columns past `act_dim`) is replaced.
        """
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """Run the guided reverse-diffusion process.

        Args:
            x: initial noise trajectories, shape (batch, horizon, state+action).
            conditions: timestep -> observation constraints (see reset_x0).
            n_guide_steps: gradient-guidance steps per scheduler timestep.
            scale: step size for the value-gradient ascent.

        Returns:
            Tuple of (final trajectories, last value predictions `y`;
            `y` stays None if the scheduler has no timesteps).
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()
                    # models expect (batch, channels, horizon); permute in/out.
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad
                # no guidance on the last two (nearly noise-free) steps
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)
            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """Plan from observation `obs` and return the best de-normalized first action."""
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = torch.randn(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # sort output trajectories by value
        sorted_idx = y.argsort(0, descending=True).squeeze()
        sorted_values = x[sorted_idx]
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # select the action with the highest value
        if y is not None:
            selected_index = 0
        else:
            # if we didn't run value guiding, select a random action
            selected_index = np.random.randint(0, batch_size)
        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
src/diffusers/models/embeddings.py
View file @
554b374d
...
@@ -62,14 +62,21 @@ def get_timestep_embedding(
...
@@ -62,14 +62,21 @@ def get_timestep_embedding(
class
TimestepEmbedding
(
nn
.
Module
):
class
TimestepEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
channel
:
int
,
time_embed_dim
:
int
,
act_fn
:
str
=
"silu"
):
def
__init__
(
self
,
in_
channel
s
:
int
,
time_embed_dim
:
int
,
act_fn
:
str
=
"silu"
,
out_dim
:
int
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
linear_1
=
nn
.
Linear
(
channel
,
time_embed_dim
)
self
.
linear_1
=
nn
.
Linear
(
in_
channel
s
,
time_embed_dim
)
self
.
act
=
None
self
.
act
=
None
if
act_fn
==
"silu"
:
if
act_fn
==
"silu"
:
self
.
act
=
nn
.
SiLU
()
self
.
act
=
nn
.
SiLU
()
self
.
linear_2
=
nn
.
Linear
(
time_embed_dim
,
time_embed_dim
)
elif
act_fn
==
"mish"
:
self
.
act
=
nn
.
Mish
()
if
out_dim
is
not
None
:
time_embed_dim_out
=
out_dim
else
:
time_embed_dim_out
=
time_embed_dim
self
.
linear_2
=
nn
.
Linear
(
time_embed_dim
,
time_embed_dim_out
)
def
forward
(
self
,
sample
):
def
forward
(
self
,
sample
):
sample
=
self
.
linear_1
(
sample
)
sample
=
self
.
linear_1
(
sample
)
...
...
src/diffusers/models/resnet.py
View file @
554b374d
...
@@ -5,6 +5,75 @@ import torch.nn as nn
...
@@ -5,6 +5,75 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
class Upsample1D(nn.Module):
    """An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        use_conv_transpose: upsample via a transposed convolution instead of interpolation.
        out_channels: output channel count; defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        # Exactly one (or neither) of the two conv options is instantiated.
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            self.conv = None

    def forward(self, x):
        assert x.shape[1] == self.channels
        # Transposed convolution does the 2x upsampling by itself.
        if self.use_conv_transpose:
            return self.conv(x)

        # Otherwise: nearest-neighbor 2x upsample, optionally followed by a conv.
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
class Downsample1D(nn.Module):
    """A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels: output channel count; defaults to `channels`.
        padding: conv padding (only used when `use_conv` is True).
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name
        halving_stride = 2

        if use_conv:
            self.conv = nn.Conv1d(
                self.channels, self.out_channels, 3, stride=halving_stride, padding=padding
            )
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=halving_stride, stride=halving_stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.conv(x)
class
Upsample2D
(
nn
.
Module
):
class
Upsample2D
(
nn
.
Module
):
"""
"""
An upsampling layer with an optional convolution.
An upsampling layer with an optional convolution.
...
@@ -12,7 +81,8 @@ class Upsample2D(nn.Module):
...
@@ -12,7 +81,8 @@ class Upsample2D(nn.Module):
Parameters:
Parameters:
channels: channels in the inputs and outputs.
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
use_conv_transpose:
out_channels:
"""
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
out_channels
=
None
,
name
=
"conv"
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
out_channels
=
None
,
name
=
"conv"
):
...
@@ -80,7 +150,8 @@ class Downsample2D(nn.Module):
...
@@ -80,7 +150,8 @@ class Downsample2D(nn.Module):
Parameters:
Parameters:
channels: channels in the inputs and outputs.
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
out_channels:
padding:
"""
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
out_channels
=
None
,
padding
=
1
,
name
=
"conv"
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
out_channels
=
None
,
padding
=
1
,
name
=
"conv"
):
...
@@ -415,6 +486,69 @@ class Mish(torch.nn.Module):
...
@@ -415,6 +486,69 @@ class Mish(torch.nn.Module):
return
hidden_states
*
torch
.
tanh
(
torch
.
nn
.
functional
.
softplus
(
hidden_states
))
return
hidden_states
*
torch
.
tanh
(
torch
.
nn
.
functional
.
softplus
(
hidden_states
))
# unet_rl.py
def rearrange_dims(tensor):
    """Toggle a singleton third dimension so GroupNorm sees a 4-D layout.

    Mapping by rank:
        2-D (N, C)       -> (N, C, 1)
        3-D (N, C, L)    -> (N, C, 1, L)
        4-D (N, C, 1, L) -> (N, C, L)

    Raises:
        ValueError: if `tensor` is not 2-, 3- or 4-dimensional.
    """
    ndim = len(tensor.shape)
    if ndim == 2:
        return tensor[:, :, None]
    if ndim == 3:
        return tensor[:, :, None, :]
    if ndim == 4:
        return tensor[:, :, 0, :]
    # Bug fix: the original message interpolated `len(tensor)` (the size of
    # dim 0), while the checks above test `len(tensor.shape)` — report the
    # quantity that was actually validated.
    raise ValueError(f"`len(tensor.shape)`: {ndim} has to be 2, 3 or 4.")
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        # "same"-style padding so the sequence length is preserved.
        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, x):
        out = self.conv1d(x)
        # GroupNorm is applied on a temporarily 4-D view, then flattened back.
        out = rearrange_dims(self.group_norm(rearrange_dims(out)))
        return self.mish(out)
# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """Residual 1-D block with an additive time-embedding conditioning path."""

    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        # 1x1 projection when channel counts differ; identity otherwise.
        if inp_channels != out_channels:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, t):
        """
        Args:
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]
        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        emb = self.time_emb(self.time_emb_act(t))
        hidden = self.conv_in(x) + rearrange_dims(emb)
        hidden = self.conv_out(hidden)
        return hidden + self.residual_conv(x)
def
upsample_2d
(
hidden_states
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
def
upsample_2d
(
hidden_states
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
r
"""Upsample2D a batch of 2D images with the given filter.
r
"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
...
...
src/diffusers/models/unet_1d.py
View file @
554b374d
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
...
@@ -8,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
...
@@ -8,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
..utils
import
BaseOutput
from
..utils
import
BaseOutput
from
.embeddings
import
GaussianFourierProjection
,
TimestepEmbedding
,
Timesteps
from
.embeddings
import
GaussianFourierProjection
,
TimestepEmbedding
,
Timesteps
from
.unet_1d_blocks
import
get_down_block
,
get_mid_block
,
get_up_block
from
.unet_1d_blocks
import
get_down_block
,
get_mid_block
,
get_out_block
,
get_up_block
@
dataclass
@
dataclass
...
@@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
implements for all the model (such as downloading or saving, etc.)
implements for all the model (such as downloading or saving, etc.)
Parameters:
Parameters:
sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime.
sample_size (`int`, *option
a
l*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`
in
t`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
freq_shift (`
floa
t`, *optional*, defaults to
0.
0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to :
flip_sin_to_cos (`bool`, *optional*, defaults to :
obj:`False`): Whether to flip sin to cos for fourier time embedding.
obj:`False`): Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to :
down_block_types (`Tuple[str]`, *optional*, defaults to :
...
@@ -43,6 +57,13 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -43,6 +57,13 @@ class UNet1DModel(ModelMixin, ConfigMixin):
obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(32, 32, 64)`): Tuple of block output channels.
obj:`(32, 32, 64)`): Tuple of block output channels.
mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet.
act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks.
norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks.
layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block.
downsample_each_block (`int`, *optional*, defaults to False:
experimental feature for using a UNet without upsampling.
"""
"""
@
register_to_config
@
register_to_config
...
@@ -54,16 +75,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -54,16 +75,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
out_channels
:
int
=
2
,
out_channels
:
int
=
2
,
extra_in_channels
:
int
=
0
,
extra_in_channels
:
int
=
0
,
time_embedding_type
:
str
=
"fourier"
,
time_embedding_type
:
str
=
"fourier"
,
freq_shift
:
int
=
0
,
flip_sin_to_cos
:
bool
=
True
,
flip_sin_to_cos
:
bool
=
True
,
use_timestep_embedding
:
bool
=
False
,
use_timestep_embedding
:
bool
=
False
,
freq_shift
:
float
=
0.0
,
down_block_types
:
Tuple
[
str
]
=
(
"DownBlock1DNoSkip"
,
"DownBlock1D"
,
"AttnDownBlock1D"
),
down_block_types
:
Tuple
[
str
]
=
(
"DownBlock1DNoSkip"
,
"DownBlock1D"
,
"AttnDownBlock1D"
),
mid_block_type
:
str
=
"UNetMidBlock1D"
,
up_block_types
:
Tuple
[
str
]
=
(
"AttnUpBlock1D"
,
"UpBlock1D"
,
"UpBlock1DNoSkip"
),
up_block_types
:
Tuple
[
str
]
=
(
"AttnUpBlock1D"
,
"UpBlock1D"
,
"UpBlock1DNoSkip"
),
mid_block_type
:
Tuple
[
str
]
=
"UNetMidBlock1D"
,
out_block_type
:
str
=
None
,
block_out_channels
:
Tuple
[
int
]
=
(
32
,
32
,
64
),
block_out_channels
:
Tuple
[
int
]
=
(
32
,
32
,
64
),
act_fn
:
str
=
None
,
norm_num_groups
:
int
=
8
,
layers_per_block
:
int
=
1
,
downsample_each_block
:
bool
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
sample_size
=
sample_size
self
.
sample_size
=
sample_size
# time
# time
...
@@ -73,12 +98,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -73,12 +98,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
)
)
timestep_input_dim
=
2
*
block_out_channels
[
0
]
timestep_input_dim
=
2
*
block_out_channels
[
0
]
elif
time_embedding_type
==
"positional"
:
elif
time_embedding_type
==
"positional"
:
self
.
time_proj
=
Timesteps
(
block_out_channels
[
0
],
flip_sin_to_cos
,
freq_shift
)
self
.
time_proj
=
Timesteps
(
block_out_channels
[
0
],
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
freq_shift
)
timestep_input_dim
=
block_out_channels
[
0
]
timestep_input_dim
=
block_out_channels
[
0
]
if
use_timestep_embedding
:
if
use_timestep_embedding
:
time_embed_dim
=
block_out_channels
[
0
]
*
4
time_embed_dim
=
block_out_channels
[
0
]
*
4
self
.
time_embedding
=
TimestepEmbedding
(
timestep_input_dim
,
time_embed_dim
)
self
.
time_mlp
=
TimestepEmbedding
(
in_channels
=
timestep_input_dim
,
time_embed_dim
=
time_embed_dim
,
act_fn
=
act_fn
,
out_dim
=
block_out_channels
[
0
],
)
self
.
down_blocks
=
nn
.
ModuleList
([])
self
.
down_blocks
=
nn
.
ModuleList
([])
self
.
mid_block
=
None
self
.
mid_block
=
None
...
@@ -94,38 +126,66 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -94,38 +126,66 @@ class UNet1DModel(ModelMixin, ConfigMixin):
if
i
==
0
:
if
i
==
0
:
input_channel
+=
extra_in_channels
input_channel
+=
extra_in_channels
is_final_block
=
i
==
len
(
block_out_channels
)
-
1
down_block
=
get_down_block
(
down_block
=
get_down_block
(
down_block_type
,
down_block_type
,
num_layers
=
layers_per_block
,
in_channels
=
input_channel
,
in_channels
=
input_channel
,
out_channels
=
output_channel
,
out_channels
=
output_channel
,
temb_channels
=
block_out_channels
[
0
],
add_downsample
=
not
is_final_block
or
downsample_each_block
,
)
)
self
.
down_blocks
.
append
(
down_block
)
self
.
down_blocks
.
append
(
down_block
)
# mid
# mid
self
.
mid_block
=
get_mid_block
(
self
.
mid_block
=
get_mid_block
(
mid_block_type
=
mid_block_type
,
mid_block_type
,
mid_channels
=
block_out_channels
[
-
1
],
in_channels
=
block_out_channels
[
-
1
],
in_channels
=
block_out_channels
[
-
1
],
out_channels
=
None
,
mid_channels
=
block_out_channels
[
-
1
],
out_channels
=
block_out_channels
[
-
1
],
embed_dim
=
block_out_channels
[
0
],
num_layers
=
layers_per_block
,
add_downsample
=
downsample_each_block
,
)
)
# up
# up
reversed_block_out_channels
=
list
(
reversed
(
block_out_channels
))
reversed_block_out_channels
=
list
(
reversed
(
block_out_channels
))
output_channel
=
reversed_block_out_channels
[
0
]
output_channel
=
reversed_block_out_channels
[
0
]
if
out_block_type
is
None
:
final_upsample_channels
=
out_channels
else
:
final_upsample_channels
=
block_out_channels
[
0
]
for
i
,
up_block_type
in
enumerate
(
up_block_types
):
for
i
,
up_block_type
in
enumerate
(
up_block_types
):
prev_output_channel
=
output_channel
prev_output_channel
=
output_channel
output_channel
=
reversed_block_out_channels
[
i
+
1
]
if
i
<
len
(
up_block_types
)
-
1
else
out_channels
output_channel
=
(
reversed_block_out_channels
[
i
+
1
]
if
i
<
len
(
up_block_types
)
-
1
else
final_upsample_channels
)
is_final_block
=
i
==
len
(
block_out_channels
)
-
1
up_block
=
get_up_block
(
up_block
=
get_up_block
(
up_block_type
,
up_block_type
,
num_layers
=
layers_per_block
,
in_channels
=
prev_output_channel
,
in_channels
=
prev_output_channel
,
out_channels
=
output_channel
,
out_channels
=
output_channel
,
temb_channels
=
block_out_channels
[
0
],
add_upsample
=
not
is_final_block
,
)
)
self
.
up_blocks
.
append
(
up_block
)
self
.
up_blocks
.
append
(
up_block
)
prev_output_channel
=
output_channel
prev_output_channel
=
output_channel
# TODO(PVP, Nathan) placeholder for RL application to be merged shortly
# out
# Totally fine to add another layer with a if statement - no need for nn.Identity here
num_groups_out
=
norm_num_groups
if
norm_num_groups
is
not
None
else
min
(
block_out_channels
[
0
]
//
4
,
32
)
self
.
out_block
=
get_out_block
(
out_block_type
=
out_block_type
,
num_groups_out
=
num_groups_out
,
embed_dim
=
block_out_channels
[
0
],
out_channels
=
out_channels
,
act_fn
=
act_fn
,
fc_dim
=
block_out_channels
[
-
1
]
//
4
,
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -144,12 +204,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -144,12 +204,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
[`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
[`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
"""
# 1. time
if
len
(
timestep
.
shape
)
==
0
:
timestep
=
timestep
[
None
]
timestep_embed
=
self
.
time_proj
(
timestep
)[...,
None
]
# 1. time
timestep_embed
=
timestep_embed
.
repeat
([
1
,
1
,
sample
.
shape
[
2
]]).
to
(
sample
.
dtype
)
timesteps
=
timestep
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
elif
torch
.
is_tensor
(
timesteps
)
and
len
(
timesteps
.
shape
)
==
0
:
timesteps
=
timesteps
[
None
].
to
(
sample
.
device
)
timestep_embed
=
self
.
time_proj
(
timesteps
)
if
self
.
config
.
use_timestep_embedding
:
timestep_embed
=
self
.
time_mlp
(
timestep_embed
)
else
:
timestep_embed
=
timestep_embed
[...,
None
]
timestep_embed
=
timestep_embed
.
repeat
([
1
,
1
,
sample
.
shape
[
2
]]).
to
(
sample
.
dtype
)
# 2. down
# 2. down
down_block_res_samples
=
()
down_block_res_samples
=
()
...
@@ -158,13 +226,18 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -158,13 +226,18 @@ class UNet1DModel(ModelMixin, ConfigMixin):
down_block_res_samples
+=
res_samples
down_block_res_samples
+=
res_samples
# 3. mid
# 3. mid
sample
=
self
.
mid_block
(
sample
)
if
self
.
mid_block
:
sample
=
self
.
mid_block
(
sample
,
timestep_embed
)
# 4. up
# 4. up
for
i
,
upsample_block
in
enumerate
(
self
.
up_blocks
):
for
i
,
upsample_block
in
enumerate
(
self
.
up_blocks
):
res_samples
=
down_block_res_samples
[
-
1
:]
res_samples
=
down_block_res_samples
[
-
1
:]
down_block_res_samples
=
down_block_res_samples
[:
-
1
]
down_block_res_samples
=
down_block_res_samples
[:
-
1
]
sample
=
upsample_block
(
sample
,
res_samples
)
sample
=
upsample_block
(
sample
,
res_hidden_states_tuple
=
res_samples
,
temb
=
timestep_embed
)
# 5. post-process
if
self
.
out_block
:
sample
=
self
.
out_block
(
sample
,
timestep_embed
)
if
not
return_dict
:
if
not
return_dict
:
return
(
sample
,)
return
(
sample
,)
...
...
src/diffusers/models/unet_1d_blocks.py
View file @
554b374d
...
@@ -17,6 +17,256 @@ import torch
...
@@ -17,6 +17,256 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch
import
nn
from
.resnet
import
Downsample1D
,
ResidualTemporalBlock1D
,
Upsample1D
,
rearrange_dims
class DownResnetBlock1D(nn.Module):
    """Stack of residual temporal blocks followed by an optional 2x downsample.

    Returns both the downsampled activations and the per-resnet skip states
    consumed later by the matching up block.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        conv_shortcut=False,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_downsample=True,
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.add_downsample = add_downsample
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        # there will always be at least one resnet
        blocks = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
        blocks.extend(
            ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)
            for _ in range(num_layers)
        )
        self.resnets = nn.ModuleList(blocks)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = nn.Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            self.nonlinearity = None

        self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) if add_downsample else None

    def forward(self, hidden_states, temb=None):
        output_states = ()

        hidden_states = self.resnets[0](hidden_states, temb)
        for block in self.resnets[1:]:
            hidden_states = block(hidden_states, temb)
        output_states += (hidden_states,)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)
        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)

        return hidden_states, output_states
class UpResnetBlock1D(nn.Module):
    """Stack of residual temporal blocks with skip-concat input and optional 2x upsample."""

    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embedding_norm = time_embedding_norm
        self.add_upsample = add_upsample
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        # there will always be at least one resnet; the first consumes the
        # concatenation of the hidden states with the skip connection (2x channels).
        blocks = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
        blocks.extend(
            ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)
            for _ in range(num_layers)
        )
        self.resnets = nn.ModuleList(blocks)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = nn.Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            self.nonlinearity = None

        self.upsample = Upsample1D(out_channels, use_conv_transpose=True) if add_upsample else None

    def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
        if res_hidden_states_tuple is not None:
            # Concatenate the most recent skip state along the channel axis.
            skip = res_hidden_states_tuple[-1]
            hidden_states = torch.cat((hidden_states, skip), dim=1)

        hidden_states = self.resnets[0](hidden_states, temb)
        for block in self.resnets[1:]:
            hidden_states = block(hidden_states, temb)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)
        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)

        return hidden_states
class ValueFunctionMidBlock1D(nn.Module):
    """Mid block for the value-function UNet: two resnet+downsample stages
    that progressively halve both channels and sequence length.

    NOTE(review): the downsample layers are sized from `out_channels` while
    the resnets are sized from `in_channels`; callers appear to pass equal
    values — confirm before reusing with differing channel counts.
    """

    def __init__(self, in_channels, out_channels, embed_dim):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embed_dim = embed_dim

        self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
        self.down1 = Downsample1D(out_channels // 2, use_conv=True)
        self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
        self.down2 = Downsample1D(out_channels // 4, use_conv=True)

    def forward(self, x, temb=None):
        out = self.down1(self.res1(x, temb))
        out = self.down2(self.res2(out, temb))
        return out
class MidResTemporalBlock1D(nn.Module):
    """Mid block built from a stack of temporal residual blocks.

    Args:
        in_channels: number of channels of the input sample.
        out_channels: number of channels produced by the block.
        embed_dim: dimension of the timestep embedding fed to each resnet.
        num_layers: number of *additional* resnets after the mandatory first one.
        add_downsample: append a strided-conv downsampler after the resnets.
        add_upsample: append an upsampler after the resnets (mutually
            exclusive with ``add_downsample``).
        non_linearity: optional activation applied after the resnets
            ("swish", "mish", "silu", or None).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dim,
        num_layers: int = 1,
        add_downsample: bool = False,
        add_upsample: bool = False,
        non_linearity=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_downsample = add_downsample

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]

        for _ in range(num_layers):
            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))

        self.resnets = nn.ModuleList(resnets)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = nn.Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            self.nonlinearity = None

        self.upsample = None
        if add_upsample:
            # NOTE(review): this "upsample" path instantiates Downsample1D,
            # which halves the sequence length — looks like a bug, but left
            # unchanged here because fixing it would alter checkpoint layouts;
            # confirm against trained weights before changing.
            self.upsample = Downsample1D(out_channels, use_conv=True)

        self.downsample = None
        if add_downsample:
            self.downsample = Downsample1D(out_channels, use_conv=True)

        if self.upsample and self.downsample:
            raise ValueError("Block cannot downsample and upsample")

    def forward(self, hidden_states, temb):
        """Run the resnet stack, the optional activation being configured at init."""
        hidden_states = self.resnets[0](hidden_states, temb)
        for resnet in self.resnets[1:]:
            hidden_states = resnet(hidden_states, temb)

        if self.upsample:
            hidden_states = self.upsample(hidden_states)
        if self.downsample:
            # BUGFIX: the original assigned the downsampler's output to
            # `self.downsample`, destroying the module after the first call
            # and returning the un-downsampled tensor.
            hidden_states = self.downsample(hidden_states)

        return hidden_states
class OutConv1DBlock(nn.Module):
    """Final projection head for the 1D UNet.

    Pipeline: 5-wide conv (same padding) -> GroupNorm -> activation ->
    1x1 conv projecting ``embed_dim`` channels to ``out_channels``.
    ``rearrange_dims`` is applied around the GroupNorm, matching the rest of
    the file.

    Args:
        num_groups_out: number of groups for the GroupNorm (must divide embed_dim).
        out_channels: channels of the final output.
        embed_dim: channels of the incoming hidden states.
        act_fn: activation name, either "silu" or "mish".

    Raises:
        ValueError: if ``act_fn`` is not one of the supported names.
    """

    def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
        super().__init__()
        self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
        self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
        # Fail fast on an unsupported activation: the original left
        # `final_conv1d_act` unset and only surfaced the problem later as an
        # AttributeError inside forward().
        if act_fn == "silu":
            self.final_conv1d_act = nn.SiLU()
        elif act_fn == "mish":
            self.final_conv1d_act = nn.Mish()
        else:
            raise ValueError(f"Unsupported act_fn: {act_fn}. Expected 'silu' or 'mish'.")
        self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)

    def forward(self, hidden_states, temb=None):
        """Project hidden states to the output channel count (temb is unused)."""
        hidden_states = self.final_conv1d_1(hidden_states)
        hidden_states = rearrange_dims(hidden_states)
        hidden_states = self.final_conv1d_gn(hidden_states)
        hidden_states = rearrange_dims(hidden_states)
        hidden_states = self.final_conv1d_act(hidden_states)
        hidden_states = self.final_conv1d_2(hidden_states)
        return hidden_states
class OutValueFunctionBlock(nn.Module):
    """Scalar value head: flattens the sample, concatenates the timestep
    embedding, and runs a small Linear -> Mish -> Linear MLP down to one
    output per batch element."""

    def __init__(self, fc_dim, embed_dim):
        super().__init__()
        self.final_block = nn.ModuleList(
            [
                nn.Linear(fc_dim + embed_dim, fc_dim // 2),
                nn.Mish(),
                nn.Linear(fc_dim // 2, 1),
            ]
        )

    def forward(self, hidden_states, temb):
        """Return a (batch, 1) value estimate for each flattened sample."""
        batch_size = hidden_states.shape[0]
        # Flatten everything past the batch dimension, then append temb.
        features = torch.cat((hidden_states.view(batch_size, -1), temb), dim=-1)
        for layer in self.final_block:
            features = layer(features)
        return features
_kernels
=
{
_kernels
=
{
"linear"
:
[
1
/
8
,
3
/
8
,
3
/
8
,
1
/
8
],
"linear"
:
[
1
/
8
,
3
/
8
,
3
/
8
,
1
/
8
],
...
@@ -62,7 +312,7 @@ class Upsample1d(nn.Module):
...
@@ -62,7 +312,7 @@ class Upsample1d(nn.Module):
self
.
pad
=
kernel_1d
.
shape
[
0
]
//
2
-
1
self
.
pad
=
kernel_1d
.
shape
[
0
]
//
2
-
1
self
.
register_buffer
(
"kernel"
,
kernel_1d
)
self
.
register_buffer
(
"kernel"
,
kernel_1d
)
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
,
temb
=
None
):
hidden_states
=
F
.
pad
(
hidden_states
,
((
self
.
pad
+
1
)
//
2
,)
*
2
,
self
.
pad_mode
)
hidden_states
=
F
.
pad
(
hidden_states
,
((
self
.
pad
+
1
)
//
2
,)
*
2
,
self
.
pad_mode
)
weight
=
hidden_states
.
new_zeros
([
hidden_states
.
shape
[
1
],
hidden_states
.
shape
[
1
],
self
.
kernel
.
shape
[
0
]])
weight
=
hidden_states
.
new_zeros
([
hidden_states
.
shape
[
1
],
hidden_states
.
shape
[
1
],
self
.
kernel
.
shape
[
0
]])
indices
=
torch
.
arange
(
hidden_states
.
shape
[
1
],
device
=
hidden_states
.
device
)
indices
=
torch
.
arange
(
hidden_states
.
shape
[
1
],
device
=
hidden_states
.
device
)
...
@@ -162,32 +412,6 @@ class ResConvBlock(nn.Module):
...
@@ -162,32 +412,6 @@ class ResConvBlock(nn.Module):
return
output
return
output
def get_down_block(down_block_type, out_channels, in_channels):
    """Build a 1D down block from its type name.

    Raises:
        ValueError: if ``down_block_type`` is not a known block name.
    """
    if down_block_type == "DownBlock1D":
        return DownBlock1D(in_channels=in_channels, out_channels=out_channels)
    if down_block_type == "AttnDownBlock1D":
        return AttnDownBlock1D(in_channels=in_channels, out_channels=out_channels)
    if down_block_type == "DownBlock1DNoSkip":
        return DownBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
    raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(up_block_type, in_channels, out_channels):
    """Build a 1D up block from its type name.

    Raises:
        ValueError: if ``up_block_type`` is not a known block name.
    """
    if up_block_type == "UpBlock1D":
        return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
    if up_block_type == "AttnUpBlock1D":
        return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
    if up_block_type == "UpBlock1DNoSkip":
        return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
    raise ValueError(f"{up_block_type} does not exist.")
def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels):
    """Build a 1D mid block from its type name.

    Raises:
        ValueError: if ``mid_block_type`` is not a known block name.
    """
    if mid_block_type == "UNetMidBlock1D":
        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
    raise ValueError(f"{mid_block_type} does not exist.")
class
UNetMidBlock1D
(
nn
.
Module
):
class
UNetMidBlock1D
(
nn
.
Module
):
def
__init__
(
self
,
mid_channels
,
in_channels
,
out_channels
=
None
):
def
__init__
(
self
,
mid_channels
,
in_channels
,
out_channels
=
None
):
super
().
__init__
()
super
().
__init__
()
...
@@ -217,7 +441,7 @@ class UNetMidBlock1D(nn.Module):
...
@@ -217,7 +441,7 @@ class UNetMidBlock1D(nn.Module):
self
.
attentions
=
nn
.
ModuleList
(
attentions
)
self
.
attentions
=
nn
.
ModuleList
(
attentions
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
,
temb
=
None
):
hidden_states
=
self
.
down
(
hidden_states
)
hidden_states
=
self
.
down
(
hidden_states
)
for
attn
,
resnet
in
zip
(
self
.
attentions
,
self
.
resnets
):
for
attn
,
resnet
in
zip
(
self
.
attentions
,
self
.
resnets
):
hidden_states
=
resnet
(
hidden_states
)
hidden_states
=
resnet
(
hidden_states
)
...
@@ -322,7 +546,7 @@ class AttnUpBlock1D(nn.Module):
...
@@ -322,7 +546,7 @@ class AttnUpBlock1D(nn.Module):
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
up
=
Upsample1d
(
kernel
=
"cubic"
)
self
.
up
=
Upsample1d
(
kernel
=
"cubic"
)
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
):
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
,
temb
=
None
):
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
...
@@ -349,7 +573,7 @@ class UpBlock1D(nn.Module):
...
@@ -349,7 +573,7 @@ class UpBlock1D(nn.Module):
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
up
=
Upsample1d
(
kernel
=
"cubic"
)
self
.
up
=
Upsample1d
(
kernel
=
"cubic"
)
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
):
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
,
temb
=
None
):
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
...
@@ -374,7 +598,7 @@ class UpBlock1DNoSkip(nn.Module):
...
@@ -374,7 +598,7 @@ class UpBlock1DNoSkip(nn.Module):
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
):
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
,
temb
=
None
):
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
...
@@ -382,3 +606,63 @@ class UpBlock1DNoSkip(nn.Module):
...
@@ -382,3 +606,63 @@ class UpBlock1DNoSkip(nn.Module):
hidden_states
=
resnet
(
hidden_states
)
hidden_states
=
resnet
(
hidden_states
)
return
hidden_states
return
hidden_states
def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
    """Build a temporal-UNet down block from its type name.

    Only ``DownResnetBlock1D`` consumes the layer/temb/downsample options; the
    remaining block types only take channel counts.

    Raises:
        ValueError: if ``down_block_type`` is not a known block name.
    """
    if down_block_type == "DownResnetBlock1D":
        return DownResnetBlock1D(
            in_channels=in_channels,
            num_layers=num_layers,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
        )
    if down_block_type == "DownBlock1D":
        return DownBlock1D(in_channels=in_channels, out_channels=out_channels)
    if down_block_type == "AttnDownBlock1D":
        return AttnDownBlock1D(in_channels=in_channels, out_channels=out_channels)
    if down_block_type == "DownBlock1DNoSkip":
        return DownBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
    raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
    """Build a temporal-UNet up block from its type name.

    Only ``UpResnetBlock1D`` consumes the layer/temb/upsample options; the
    remaining block types only take channel counts.

    Raises:
        ValueError: if ``up_block_type`` is not a known block name.
    """
    if up_block_type == "UpResnetBlock1D":
        return UpResnetBlock1D(
            in_channels=in_channels,
            num_layers=num_layers,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
        )
    if up_block_type == "UpBlock1D":
        return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
    if up_block_type == "AttnUpBlock1D":
        return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
    if up_block_type == "UpBlock1DNoSkip":
        return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
    raise ValueError(f"{up_block_type} does not exist.")
def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
    """Build a temporal-UNet mid block from its type name.

    Each known type consumes a different subset of the arguments; unused
    arguments are ignored for that type.

    Raises:
        ValueError: if ``mid_block_type`` is not a known block name.
    """
    if mid_block_type == "MidResTemporalBlock1D":
        return MidResTemporalBlock1D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            embed_dim=embed_dim,
            add_downsample=add_downsample,
        )
    if mid_block_type == "ValueFunctionMidBlock1D":
        return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
    if mid_block_type == "UNetMidBlock1D":
        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
    raise ValueError(f"{mid_block_type} does not exist.")
def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
    """Build the optional output head for the 1D UNet.

    Unlike the other factory helpers in this file, an unknown
    ``out_block_type`` deliberately yields ``None`` (no output head) instead
    of raising.
    """
    if out_block_type == "OutConv1DBlock":
        return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
    if out_block_type == "ValueFunction":
        return OutValueFunctionBlock(fc_dim, embed_dim)
    return None
src/diffusers/models/unet_2d.py
View file @
554b374d
...
@@ -51,7 +51,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
...
@@ -51,7 +51,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to :
flip_sin_to_cos (`bool`, *optional*, defaults to :
obj:`
Fals
e`): Whether to flip sin to cos for fourier time embedding.
obj:`
Tru
e`): Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to :
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
types.
types.
...
...
src/diffusers/models/unet_2d_condition.py
View file @
554b374d
...
@@ -60,7 +60,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
...
@@ -60,7 +60,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `
Fals
e`):
flip_sin_to_cos (`bool`, *optional*, defaults to `
Tru
e`):
Whether to flip the sin to cos in the time embedding.
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
...
...
src/diffusers/pipeline_flax_utils.py
View file @
554b374d
...
@@ -47,7 +47,7 @@ logger = logging.get_logger(__name__)
...
@@ -47,7 +47,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES
=
{
LOADABLE_CLASSES
=
{
"diffusers"
:
{
"diffusers"
:
{
"FlaxModelMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FlaxModelMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FlaxSchedulerMixin"
:
[
"save_
config"
,
"from_config
"
],
"FlaxSchedulerMixin"
:
[
"save_
pretrained"
,
"from_pretrained
"
],
"FlaxDiffusionPipeline"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FlaxDiffusionPipeline"
:
[
"save_pretrained"
,
"from_pretrained"
],
},
},
"transformers"
:
{
"transformers"
:
{
...
@@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
>>> from diffusers import FlaxDPMSolverMultistepScheduler
>>> from diffusers import FlaxDPMSolverMultistepScheduler
>>> model_id = "runwayml/stable-diffusion-v1-5"
>>> model_id = "runwayml/stable-diffusion-v1-5"
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_
config
(
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_
pretrained
(
... model_id,
... model_id,
... subfolder="scheduler",
... subfolder="scheduler",
... )
... )
...
@@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
# use snapshot download here to get it working from from_pretrained
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
config_dict
=
cls
.
get
_config
_dict
(
config_dict
=
cls
.
load
_config
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
resume_download
=
resume_download
,
resume_download
=
resume_download
,
...
@@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
else
:
else
:
cached_folder
=
pretrained_model_name_or_path
cached_folder
=
pretrained_model_name_or_path
config_dict
=
cls
.
get
_config
_dict
(
cached_folder
)
config_dict
=
cls
.
load
_config
(
cached_folder
)
# 2. Load the pipeline class, if using custom module then load it from the hub
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
# if we load from explicit class, let's use it
...
@@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
expected_modules
=
set
(
inspect
.
signature
(
pipeline_class
.
__init__
).
parameters
.
keys
())
expected_modules
=
set
(
inspect
.
signature
(
pipeline_class
.
__init__
).
parameters
.
keys
())
passed_class_obj
=
{
k
:
kwargs
.
pop
(
k
)
for
k
in
expected_modules
if
k
in
kwargs
}
passed_class_obj
=
{
k
:
kwargs
.
pop
(
k
)
for
k
in
expected_modules
if
k
in
kwargs
}
init_dict
,
_
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
init_dict
,
_
,
_
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
init_kwargs
=
{}
init_kwargs
=
{}
...
...
src/diffusers/pipeline_utils.py
View file @
554b374d
...
@@ -65,7 +65,7 @@ logger = logging.get_logger(__name__)
...
@@ -65,7 +65,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES
=
{
LOADABLE_CLASSES
=
{
"diffusers"
:
{
"diffusers"
:
{
"ModelMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"ModelMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"SchedulerMixin"
:
[
"save_
config"
,
"from_config
"
],
"SchedulerMixin"
:
[
"save_
pretrained"
,
"from_pretrained
"
],
"DiffusionPipeline"
:
[
"save_pretrained"
,
"from_pretrained"
],
"DiffusionPipeline"
:
[
"save_pretrained"
,
"from_pretrained"
],
"OnnxRuntimeModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"OnnxRuntimeModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
},
},
...
@@ -207,7 +207,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -207,7 +207,7 @@ class DiffusionPipeline(ConfigMixin):
if
torch_device
is
None
:
if
torch_device
is
None
:
return
self
return
self
module_names
,
_
=
self
.
extract_init_dict
(
dict
(
self
.
config
))
module_names
,
_
,
_
=
self
.
extract_init_dict
(
dict
(
self
.
config
))
for
name
in
module_names
.
keys
():
for
name
in
module_names
.
keys
():
module
=
getattr
(
self
,
name
)
module
=
getattr
(
self
,
name
)
if
isinstance
(
module
,
torch
.
nn
.
Module
):
if
isinstance
(
module
,
torch
.
nn
.
Module
):
...
@@ -228,7 +228,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -228,7 +228,7 @@ class DiffusionPipeline(ConfigMixin):
Returns:
Returns:
`torch.device`: The torch device on which the pipeline is located.
`torch.device`: The torch device on which the pipeline is located.
"""
"""
module_names
,
_
=
self
.
extract_init_dict
(
dict
(
self
.
config
))
module_names
,
_
,
_
=
self
.
extract_init_dict
(
dict
(
self
.
config
))
for
name
in
module_names
.
keys
():
for
name
in
module_names
.
keys
():
module
=
getattr
(
self
,
name
)
module
=
getattr
(
self
,
name
)
if
isinstance
(
module
,
torch
.
nn
.
Module
):
if
isinstance
(
module
,
torch
.
nn
.
Module
):
...
@@ -377,11 +377,11 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -377,11 +377,11 @@ class DiffusionPipeline(ConfigMixin):
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> #
Download pipeline, but overwrite
scheduler
>>> #
Use a different
scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler.from_config(
"runwayml/stable-diffusion-v1-5", subfolder="
scheduler
"
)
>>> scheduler = LMSDiscreteScheduler.from_config(
pipeline.
scheduler
.config
)
>>> pipeline
= DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
scheduler
=
scheduler
)
>>> pipeline
.
scheduler
=
scheduler
```
```
"""
"""
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
...
@@ -428,7 +428,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -428,7 +428,7 @@ class DiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
# use snapshot download here to get it working from from_pretrained
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
config_dict
=
cls
.
get
_config
_dict
(
config_dict
=
cls
.
load
_config
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
resume_download
=
resume_download
,
resume_download
=
resume_download
,
...
@@ -474,7 +474,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -474,7 +474,7 @@ class DiffusionPipeline(ConfigMixin):
else
:
else
:
cached_folder
=
pretrained_model_name_or_path
cached_folder
=
pretrained_model_name_or_path
config_dict
=
cls
.
get
_config
_dict
(
cached_folder
)
config_dict
=
cls
.
load
_config
(
cached_folder
)
# 2. Load the pipeline class, if using custom module then load it from the hub
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
# if we load from explicit class, let's use it
...
@@ -513,7 +513,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -513,7 +513,7 @@ class DiffusionPipeline(ConfigMixin):
expected_modules
=
set
(
inspect
.
signature
(
pipeline_class
.
__init__
).
parameters
.
keys
())
-
set
([
"self"
])
expected_modules
=
set
(
inspect
.
signature
(
pipeline_class
.
__init__
).
parameters
.
keys
())
-
set
([
"self"
])
passed_class_obj
=
{
k
:
kwargs
.
pop
(
k
)
for
k
in
expected_modules
if
k
in
kwargs
}
passed_class_obj
=
{
k
:
kwargs
.
pop
(
k
)
for
k
in
expected_modules
if
k
in
kwargs
}
init_dict
,
unused_kwargs
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
init_dict
,
unused_kwargs
,
_
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
if
len
(
unused_kwargs
)
>
0
:
if
len
(
unused_kwargs
)
>
0
:
logger
.
warning
(
f
"Keyword arguments
{
unused_kwargs
}
not recognized."
)
logger
.
warning
(
f
"Keyword arguments
{
unused_kwargs
}
not recognized."
)
...
...
src/diffusers/pipelines/README.md
View file @
554b374d
...
@@ -40,7 +40,7 @@ available a colab notebook to directly try them out.
...
@@ -40,7 +40,7 @@ available a colab notebook to directly try them out.
|
[
pndm
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm
)
|
[
**Pseudo Numerical Methods for Diffusion Models on Manifolds**
](
https://arxiv.org/abs/2202.09778
)
|
*Unconditional Image Generation*
|
|
[
pndm
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm
)
|
[
**Pseudo Numerical Methods for Diffusion Models on Manifolds**
](
https://arxiv.org/abs/2202.09778
)
|
*Unconditional Image Generation*
|
|
[
score_sde_ve
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve
)
|
[
**Score-Based Generative Modeling through Stochastic Differential Equations**
](
https://openreview.net/forum?id=PxTIG12RRHS
)
|
*Unconditional Image Generation*
|
|
[
score_sde_ve
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve
)
|
[
**Score-Based Generative Modeling through Stochastic Differential Equations**
](
https://openreview.net/forum?id=PxTIG12RRHS
)
|
*Unconditional Image Generation*
|
|
[
score_sde_vp
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp
)
|
[
**Score-Based Generative Modeling through Stochastic Differential Equations**
](
https://openreview.net/forum?id=PxTIG12RRHS
)
|
*Unconditional Image Generation*
|
|
[
score_sde_vp
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp
)
|
[
**Score-Based Generative Modeling through Stochastic Differential Equations**
](
https://openreview.net/forum?id=PxTIG12RRHS
)
|
*Unconditional Image Generation*
|
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Text-to-Image Generation*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/
training_example
.ipynb)
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Text-to-Image Generation*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/
stable_diffusion
.ipynb)
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Image-to-Image Text-Guided Generation*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Image-to-Image Text-Guided Generation*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Text-Guided Image Inpainting*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Text-Guided Image Inpainting*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
[
stochastic_karras_ve
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve
)
|
[
**Elucidating the Design Space of Diffusion-Based Generative Models**
](
https://arxiv.org/abs/2206.00364
)
|
*Unconditional Image Generation*
|
|
[
stochastic_karras_ve
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve
)
|
[
**Elucidating the Design Space of Diffusion-Based Generative Models**
](
https://arxiv.org/abs/2206.00364
)
|
*Unconditional Image Generation*
|
...
...
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
554b374d
...
@@ -71,7 +71,7 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -71,7 +71,7 @@ class DDPMPipeline(DiffusionPipeline):
"""
"""
message
=
(
message
=
(
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_
config
(<model_id>, predict_epsilon=True)`."
" DDPMScheduler.from_
pretrained
(<model_id>, predict_epsilon=True)`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
...
...
src/diffusers/pipelines/stable_diffusion/README.md
View file @
554b374d
...
@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
...
@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
# make sure you're logged in with `huggingface-cli login`
from
diffusers
import
StableDiffusionPipeline
,
DDIMScheduler
from
diffusers
import
StableDiffusionPipeline
,
DDIMScheduler
scheduler
=
DDIMScheduler
.
from_
config
(
"CompVis/stable-diffusion-v1-4"
,
subfolder
=
"scheduler"
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
"CompVis/stable-diffusion-v1-4"
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"runwayml/stable-diffusion-v1-5"
,
"runwayml/stable-diffusion-v1-5"
,
...
@@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png")
...
@@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
# make sure you're logged in with `huggingface-cli login`
from
diffusers
import
StableDiffusionPipeline
,
LMSDiscreteScheduler
from
diffusers
import
StableDiffusionPipeline
,
LMSDiscreteScheduler
lms
=
LMSDiscreteScheduler
.
from_
config
(
"CompVis/stable-diffusion-v1-4"
,
subfolder
=
"scheduler"
)
lms
=
LMSDiscreteScheduler
.
from_
pretrained
(
"CompVis/stable-diffusion-v1-4"
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"runwayml/stable-diffusion-v1-5"
,
"runwayml/stable-diffusion-v1-5"
,
...
@@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
...
@@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
# make sure you're logged in with `huggingface-cli login`
model_id_or_path
=
"CompVis/stable-diffusion-v1-4"
model_id_or_path
=
"CompVis/stable-diffusion-v1-4"
scheduler
=
DDIMScheduler
.
from_
config
(
model_id_or_path
,
subfolder
=
"scheduler"
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
model_id_or_path
,
subfolder
=
"scheduler"
)
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
model_id_or_path
,
scheduler
=
scheduler
).
to
(
"cuda"
)
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
model_id_or_path
,
scheduler
=
scheduler
).
to
(
"cuda"
)
# let's download an initial image
# let's download an initial image
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
554b374d
...
@@ -23,7 +23,7 @@ import numpy as np
...
@@ -23,7 +23,7 @@ import numpy as np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
BaseOutput
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
BaseOutput
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
For more details, see the original paper: https://arxiv.org/abs/2010.02502
...
@@ -109,14 +109,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -109,14 +109,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatible_classes
=
[
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
"PNDMScheduler"
,
"DDPMScheduler"
,
"LMSDiscreteScheduler"
,
"EulerDiscreteScheduler"
,
"EulerAncestralDiscreteScheduler"
,
"DPMSolverMultistepScheduler"
,
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_ddim_flax.py
View file @
554b374d
...
@@ -23,7 +23,12 @@ import flax
...
@@ -23,7 +23,12 @@ import flax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils_flax
import
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
from
.scheduling_utils_flax
import
(
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
,
)
def
betas_for_alpha_bar
(
num_diffusion_timesteps
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
def
betas_for_alpha_bar
(
num_diffusion_timesteps
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
...
@@ -79,8 +84,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -79,8 +84,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
For more details, see the original paper: https://arxiv.org/abs/2010.02502
...
@@ -105,6 +110,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -105,6 +110,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
stable diffusion.
stable diffusion.
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
return
True
return
True
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
554b374d
...
@@ -22,7 +22,7 @@ import numpy as np
...
@@ -22,7 +22,7 @@ import numpy as np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
,
FrozenDict
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
FrozenDict
,
register_to_config
from
..utils
import
BaseOutput
,
deprecate
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
BaseOutput
,
deprecate
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
For more details, see the original paper: https://arxiv.org/abs/2006.11239
...
@@ -104,14 +104,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -104,14 +104,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatible_classes
=
[
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
"DDIMScheduler"
,
"PNDMScheduler"
,
"LMSDiscreteScheduler"
,
"EulerDiscreteScheduler"
,
"EulerAncestralDiscreteScheduler"
,
"DPMSolverMultistepScheduler"
,
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
@@ -204,6 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -204,6 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# for rl-diffuser https://arxiv.org/abs/2205.09991
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif
variance_type
==
"fixed_small_log"
:
elif
variance_type
==
"fixed_small_log"
:
variance
=
torch
.
log
(
torch
.
clamp
(
variance
,
min
=
1e-20
))
variance
=
torch
.
log
(
torch
.
clamp
(
variance
,
min
=
1e-20
))
variance
=
torch
.
exp
(
0.5
*
variance
)
elif
variance_type
==
"fixed_large"
:
elif
variance_type
==
"fixed_large"
:
variance
=
self
.
betas
[
t
]
variance
=
self
.
betas
[
t
]
elif
variance_type
==
"fixed_large_log"
:
elif
variance_type
==
"fixed_large_log"
:
...
@@ -248,7 +242,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -248,7 +242,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
message
=
(
message
=
(
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_
config
(<model_id>, predict_epsilon=True)`."
" DDPMScheduler.from_
pretrained
(<model_id>, predict_epsilon=True)`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
and
predict_epsilon
!=
self
.
config
.
predict_epsilon
:
if
predict_epsilon
is
not
None
and
predict_epsilon
!=
self
.
config
.
predict_epsilon
:
...
@@ -301,7 +295,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -301,7 +295,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_noise
=
torch
.
randn
(
variance_noise
=
torch
.
randn
(
model_output
.
shape
,
generator
=
generator
,
device
=
device
,
dtype
=
model_output
.
dtype
model_output
.
shape
,
generator
=
generator
,
device
=
device
,
dtype
=
model_output
.
dtype
)
)
variance
=
(
self
.
_get_variance
(
t
,
predicted_variance
=
predicted_variance
)
**
0.5
)
*
variance_noise
if
self
.
variance_type
==
"fixed_small_log"
:
variance
=
self
.
_get_variance
(
t
,
predicted_variance
=
predicted_variance
)
*
variance_noise
else
:
variance
=
(
self
.
_get_variance
(
t
,
predicted_variance
=
predicted_variance
)
**
0.5
)
*
variance_noise
pred_prev_sample
=
pred_prev_sample
+
variance
pred_prev_sample
=
pred_prev_sample
+
variance
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
554b374d
...
@@ -24,7 +24,12 @@ from jax import random
...
@@ -24,7 +24,12 @@ from jax import random
from
..configuration_utils
import
ConfigMixin
,
FrozenDict
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
FrozenDict
,
register_to_config
from
..utils
import
deprecate
from
..utils
import
deprecate
from
.scheduling_utils_flax
import
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
from
.scheduling_utils_flax
import
(
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
,
)
def
betas_for_alpha_bar
(
num_diffusion_timesteps
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
def
betas_for_alpha_bar
(
num_diffusion_timesteps
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
...
@@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
For more details, see the original paper: https://arxiv.org/abs/2006.11239
...
@@ -103,6 +108,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -103,6 +108,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
return
True
return
True
...
@@ -221,7 +228,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -221,7 +228,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
message
=
(
message
=
(
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_
config
(<model_id>, predict_epsilon=True)`."
" DDPMScheduler.from_
pretrained
(<model_id>, predict_epsilon=True)`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
and
predict_epsilon
!=
self
.
config
.
predict_epsilon
:
if
predict_epsilon
is
not
None
and
predict_epsilon
!=
self
.
config
.
predict_epsilon
:
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
554b374d
...
@@ -21,6 +21,7 @@ import numpy as np
...
@@ -21,6 +21,7 @@ import numpy as np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from
.scheduling_utils
import
SchedulerMixin
,
SchedulerOutput
from
.scheduling_utils
import
SchedulerMixin
,
SchedulerOutput
...
@@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
Args:
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
num_train_timesteps (`int`): number of diffusion steps used to train the model.
...
@@ -116,14 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -116,14 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatible_classes
=
[
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
"DDIMScheduler"
,
"DDPMScheduler"
,
"PNDMScheduler"
,
"LMSDiscreteScheduler"
,
"EulerDiscreteScheduler"
,
"EulerAncestralDiscreteScheduler"
,
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
View file @
554b374d
...
@@ -23,7 +23,12 @@ import jax
...
@@ -23,7 +23,12 @@ import jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils_flax
import
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
from
.scheduling_utils_flax
import
(
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
,
)
def
betas_for_alpha_bar
(
num_diffusion_timesteps
:
int
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
def
betas_for_alpha_bar
(
num_diffusion_timesteps
:
int
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
...
@@ -96,8 +101,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -96,8 +101,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
...
@@ -143,6 +148,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -143,6 +148,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
return
True
return
True
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment