Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
554b374d
Commit
554b374d
authored
Nov 15, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
d5ab55e4
a0520193
Changes
76
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
756 additions
and
122 deletions
+756
-122
src/diffusers/experimental/__init__.py
src/diffusers/experimental/__init__.py
+1
-0
src/diffusers/experimental/rl/__init__.py
src/diffusers/experimental/rl/__init__.py
+1
-0
src/diffusers/experimental/rl/value_guided_sampling.py
src/diffusers/experimental/rl/value_guided_sampling.py
+129
-0
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+10
-3
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+136
-2
src/diffusers/models/unet_1d.py
src/diffusers/models/unet_1d.py
+94
-21
src/diffusers/models/unet_1d_blocks.py
src/diffusers/models/unet_1d_blocks.py
+315
-31
src/diffusers/models/unet_2d.py
src/diffusers/models/unet_2d.py
+1
-1
src/diffusers/models/unet_2d_condition.py
src/diffusers/models/unet_2d_condition.py
+1
-1
src/diffusers/pipeline_flax_utils.py
src/diffusers/pipeline_flax_utils.py
+5
-5
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+9
-9
src/diffusers/pipelines/README.md
src/diffusers/pipelines/README.md
+1
-1
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+1
-1
src/diffusers/pipelines/stable_diffusion/README.md
src/diffusers/pipelines/stable_diffusion/README.md
+3
-3
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+4
-11
src/diffusers/schedulers/scheduling_ddim_flax.py
src/diffusers/schedulers/scheduling_ddim_flax.py
+10
-3
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+10
-13
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+11
-4
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+4
-10
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
...ffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+10
-3
No files found.
src/diffusers/experimental/__init__.py
0 → 100644
View file @
554b374d
from
.rl
import
ValueGuidedRLPipeline
src/diffusers/experimental/rl/__init__.py
0 → 100644
View file @
554b374d
from
.value_guided_sampling
import
ValueGuidedRLPipeline
src/diffusers/experimental/rl/value_guided_sampling.py
0 → 100644
View file @
554b374d
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
torch
import
tqdm
from
...models.unet_1d
import
UNet1DModel
from
...pipeline_utils
import
DiffusionPipeline
from
...utils.dummy_pt_objects
import
DDPMScheduler
class
ValueGuidedRLPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
value_function
:
UNet1DModel
,
unet
:
UNet1DModel
,
scheduler
:
DDPMScheduler
,
env
,
):
super
().
__init__
()
self
.
value_function
=
value_function
self
.
unet
=
unet
self
.
scheduler
=
scheduler
self
.
env
=
env
self
.
data
=
env
.
get_dataset
()
self
.
means
=
dict
()
for
key
in
self
.
data
.
keys
():
try
:
self
.
means
[
key
]
=
self
.
data
[
key
].
mean
()
except
:
pass
self
.
stds
=
dict
()
for
key
in
self
.
data
.
keys
():
try
:
self
.
stds
[
key
]
=
self
.
data
[
key
].
std
()
except
:
pass
self
.
state_dim
=
env
.
observation_space
.
shape
[
0
]
self
.
action_dim
=
env
.
action_space
.
shape
[
0
]
def
normalize
(
self
,
x_in
,
key
):
return
(
x_in
-
self
.
means
[
key
])
/
self
.
stds
[
key
]
def
de_normalize
(
self
,
x_in
,
key
):
return
x_in
*
self
.
stds
[
key
]
+
self
.
means
[
key
]
def
to_torch
(
self
,
x_in
):
if
type
(
x_in
)
is
dict
:
return
{
k
:
self
.
to_torch
(
v
)
for
k
,
v
in
x_in
.
items
()}
elif
torch
.
is_tensor
(
x_in
):
return
x_in
.
to
(
self
.
unet
.
device
)
return
torch
.
tensor
(
x_in
,
device
=
self
.
unet
.
device
)
def
reset_x0
(
self
,
x_in
,
cond
,
act_dim
):
for
key
,
val
in
cond
.
items
():
x_in
[:,
key
,
act_dim
:]
=
val
.
clone
()
return
x_in
def
run_diffusion
(
self
,
x
,
conditions
,
n_guide_steps
,
scale
):
batch_size
=
x
.
shape
[
0
]
y
=
None
for
i
in
tqdm
.
tqdm
(
self
.
scheduler
.
timesteps
):
# create batch of timesteps to pass into model
timesteps
=
torch
.
full
((
batch_size
,),
i
,
device
=
self
.
unet
.
device
,
dtype
=
torch
.
long
)
for
_
in
range
(
n_guide_steps
):
with
torch
.
enable_grad
():
x
.
requires_grad_
()
y
=
self
.
value_function
(
x
.
permute
(
0
,
2
,
1
),
timesteps
).
sample
grad
=
torch
.
autograd
.
grad
([
y
.
sum
()],
[
x
])[
0
]
posterior_variance
=
self
.
scheduler
.
_get_variance
(
i
)
model_std
=
torch
.
exp
(
0.5
*
posterior_variance
)
grad
=
model_std
*
grad
grad
[
timesteps
<
2
]
=
0
x
=
x
.
detach
()
x
=
x
+
scale
*
grad
x
=
self
.
reset_x0
(
x
,
conditions
,
self
.
action_dim
)
prev_x
=
self
.
unet
(
x
.
permute
(
0
,
2
,
1
),
timesteps
).
sample
.
permute
(
0
,
2
,
1
)
x
=
self
.
scheduler
.
step
(
prev_x
,
i
,
x
,
predict_epsilon
=
False
)[
"prev_sample"
]
# apply conditions to the trajectory
x
=
self
.
reset_x0
(
x
,
conditions
,
self
.
action_dim
)
x
=
self
.
to_torch
(
x
)
return
x
,
y
def
__call__
(
self
,
obs
,
batch_size
=
64
,
planning_horizon
=
32
,
n_guide_steps
=
2
,
scale
=
0.1
):
# normalize the observations and create batch dimension
obs
=
self
.
normalize
(
obs
,
"observations"
)
obs
=
obs
[
None
].
repeat
(
batch_size
,
axis
=
0
)
conditions
=
{
0
:
self
.
to_torch
(
obs
)}
shape
=
(
batch_size
,
planning_horizon
,
self
.
state_dim
+
self
.
action_dim
)
# generate initial noise and apply our conditions (to make the trajectories start at current state)
x1
=
torch
.
randn
(
shape
,
device
=
self
.
unet
.
device
)
x
=
self
.
reset_x0
(
x1
,
conditions
,
self
.
action_dim
)
x
=
self
.
to_torch
(
x
)
# run the diffusion process
x
,
y
=
self
.
run_diffusion
(
x
,
conditions
,
n_guide_steps
,
scale
)
# sort output trajectories by value
sorted_idx
=
y
.
argsort
(
0
,
descending
=
True
).
squeeze
()
sorted_values
=
x
[
sorted_idx
]
actions
=
sorted_values
[:,
:,
:
self
.
action_dim
]
actions
=
actions
.
detach
().
cpu
().
numpy
()
denorm_actions
=
self
.
de_normalize
(
actions
,
key
=
"actions"
)
# select the action with the highest value
if
y
is
not
None
:
selected_index
=
0
else
:
# if we didn't run value guiding, select a random action
selected_index
=
np
.
random
.
randint
(
0
,
batch_size
)
denorm_actions
=
denorm_actions
[
selected_index
,
0
]
return
denorm_actions
src/diffusers/models/embeddings.py
View file @
554b374d
...
@@ -62,14 +62,21 @@ def get_timestep_embedding(
...
@@ -62,14 +62,21 @@ def get_timestep_embedding(
class
TimestepEmbedding
(
nn
.
Module
):
class
TimestepEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
channel
:
int
,
time_embed_dim
:
int
,
act_fn
:
str
=
"silu"
):
def
__init__
(
self
,
in_
channel
s
:
int
,
time_embed_dim
:
int
,
act_fn
:
str
=
"silu"
,
out_dim
:
int
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
linear_1
=
nn
.
Linear
(
channel
,
time_embed_dim
)
self
.
linear_1
=
nn
.
Linear
(
in_
channel
s
,
time_embed_dim
)
self
.
act
=
None
self
.
act
=
None
if
act_fn
==
"silu"
:
if
act_fn
==
"silu"
:
self
.
act
=
nn
.
SiLU
()
self
.
act
=
nn
.
SiLU
()
self
.
linear_2
=
nn
.
Linear
(
time_embed_dim
,
time_embed_dim
)
elif
act_fn
==
"mish"
:
self
.
act
=
nn
.
Mish
()
if
out_dim
is
not
None
:
time_embed_dim_out
=
out_dim
else
:
time_embed_dim_out
=
time_embed_dim
self
.
linear_2
=
nn
.
Linear
(
time_embed_dim
,
time_embed_dim_out
)
def
forward
(
self
,
sample
):
def
forward
(
self
,
sample
):
sample
=
self
.
linear_1
(
sample
)
sample
=
self
.
linear_1
(
sample
)
...
...
src/diffusers/models/resnet.py
View file @
554b374d
...
@@ -5,6 +5,75 @@ import torch.nn as nn
...
@@ -5,6 +5,75 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
class
Upsample1D
(
nn
.
Module
):
"""
An upsampling layer with an optional convolution.
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
use_conv_transpose:
out_channels:
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
out_channels
=
None
,
name
=
"conv"
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
use_conv_transpose
=
use_conv_transpose
self
.
name
=
name
self
.
conv
=
None
if
use_conv_transpose
:
self
.
conv
=
nn
.
ConvTranspose1d
(
channels
,
self
.
out_channels
,
4
,
2
,
1
)
elif
use_conv
:
self
.
conv
=
nn
.
Conv1d
(
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
if
self
.
use_conv_transpose
:
return
self
.
conv
(
x
)
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
return
x
class
Downsample1D
(
nn
.
Module
):
"""
A downsampling layer with an optional convolution.
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
out_channels:
padding:
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
out_channels
=
None
,
padding
=
1
,
name
=
"conv"
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
padding
=
padding
stride
=
2
self
.
name
=
name
if
use_conv
:
self
.
conv
=
nn
.
Conv1d
(
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
)
else
:
assert
self
.
channels
==
self
.
out_channels
self
.
conv
=
nn
.
AvgPool1d
(
kernel_size
=
stride
,
stride
=
stride
)
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
return
self
.
conv
(
x
)
class
Upsample2D
(
nn
.
Module
):
class
Upsample2D
(
nn
.
Module
):
"""
"""
An upsampling layer with an optional convolution.
An upsampling layer with an optional convolution.
...
@@ -12,7 +81,8 @@ class Upsample2D(nn.Module):
...
@@ -12,7 +81,8 @@ class Upsample2D(nn.Module):
Parameters:
Parameters:
channels: channels in the inputs and outputs.
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
use_conv_transpose:
out_channels:
"""
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
out_channels
=
None
,
name
=
"conv"
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
use_conv_transpose
=
False
,
out_channels
=
None
,
name
=
"conv"
):
...
@@ -80,7 +150,8 @@ class Downsample2D(nn.Module):
...
@@ -80,7 +150,8 @@ class Downsample2D(nn.Module):
Parameters:
Parameters:
channels: channels in the inputs and outputs.
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
out_channels:
padding:
"""
"""
def
__init__
(
self
,
channels
,
use_conv
=
False
,
out_channels
=
None
,
padding
=
1
,
name
=
"conv"
):
def
__init__
(
self
,
channels
,
use_conv
=
False
,
out_channels
=
None
,
padding
=
1
,
name
=
"conv"
):
...
@@ -415,6 +486,69 @@ class Mish(torch.nn.Module):
...
@@ -415,6 +486,69 @@ class Mish(torch.nn.Module):
return
hidden_states
*
torch
.
tanh
(
torch
.
nn
.
functional
.
softplus
(
hidden_states
))
return
hidden_states
*
torch
.
tanh
(
torch
.
nn
.
functional
.
softplus
(
hidden_states
))
# unet_rl.py
def
rearrange_dims
(
tensor
):
if
len
(
tensor
.
shape
)
==
2
:
return
tensor
[:,
:,
None
]
if
len
(
tensor
.
shape
)
==
3
:
return
tensor
[:,
:,
None
,
:]
elif
len
(
tensor
.
shape
)
==
4
:
return
tensor
[:,
:,
0
,
:]
else
:
raise
ValueError
(
f
"`len(tensor)`:
{
len
(
tensor
)
}
has to be 2, 3 or 4."
)
class
Conv1dBlock
(
nn
.
Module
):
"""
Conv1d --> GroupNorm --> Mish
"""
def
__init__
(
self
,
inp_channels
,
out_channels
,
kernel_size
,
n_groups
=
8
):
super
().
__init__
()
self
.
conv1d
=
nn
.
Conv1d
(
inp_channels
,
out_channels
,
kernel_size
,
padding
=
kernel_size
//
2
)
self
.
group_norm
=
nn
.
GroupNorm
(
n_groups
,
out_channels
)
self
.
mish
=
nn
.
Mish
()
def
forward
(
self
,
x
):
x
=
self
.
conv1d
(
x
)
x
=
rearrange_dims
(
x
)
x
=
self
.
group_norm
(
x
)
x
=
rearrange_dims
(
x
)
x
=
self
.
mish
(
x
)
return
x
# unet_rl.py
class
ResidualTemporalBlock1D
(
nn
.
Module
):
def
__init__
(
self
,
inp_channels
,
out_channels
,
embed_dim
,
kernel_size
=
5
):
super
().
__init__
()
self
.
conv_in
=
Conv1dBlock
(
inp_channels
,
out_channels
,
kernel_size
)
self
.
conv_out
=
Conv1dBlock
(
out_channels
,
out_channels
,
kernel_size
)
self
.
time_emb_act
=
nn
.
Mish
()
self
.
time_emb
=
nn
.
Linear
(
embed_dim
,
out_channels
)
self
.
residual_conv
=
(
nn
.
Conv1d
(
inp_channels
,
out_channels
,
1
)
if
inp_channels
!=
out_channels
else
nn
.
Identity
()
)
def
forward
(
self
,
x
,
t
):
"""
Args:
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
"""
t
=
self
.
time_emb_act
(
t
)
t
=
self
.
time_emb
(
t
)
out
=
self
.
conv_in
(
x
)
+
rearrange_dims
(
t
)
out
=
self
.
conv_out
(
out
)
return
out
+
self
.
residual_conv
(
x
)
def
upsample_2d
(
hidden_states
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
def
upsample_2d
(
hidden_states
,
kernel
=
None
,
factor
=
2
,
gain
=
1
):
r
"""Upsample2D a batch of 2D images with the given filter.
r
"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
...
...
src/diffusers/models/unet_1d.py
View file @
554b374d
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
...
@@ -8,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
...
@@ -8,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
..utils
import
BaseOutput
from
..utils
import
BaseOutput
from
.embeddings
import
GaussianFourierProjection
,
TimestepEmbedding
,
Timesteps
from
.embeddings
import
GaussianFourierProjection
,
TimestepEmbedding
,
Timesteps
from
.unet_1d_blocks
import
get_down_block
,
get_mid_block
,
get_up_block
from
.unet_1d_blocks
import
get_down_block
,
get_mid_block
,
get_out_block
,
get_up_block
@
dataclass
@
dataclass
...
@@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
implements for all the model (such as downloading or saving, etc.)
implements for all the model (such as downloading or saving, etc.)
Parameters:
Parameters:
sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime.
sample_size (`int`, *option
a
l*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`
in
t`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
freq_shift (`
floa
t`, *optional*, defaults to
0.
0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to :
flip_sin_to_cos (`bool`, *optional*, defaults to :
obj:`False`): Whether to flip sin to cos for fourier time embedding.
obj:`False`): Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to :
down_block_types (`Tuple[str]`, *optional*, defaults to :
...
@@ -43,6 +57,13 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -43,6 +57,13 @@ class UNet1DModel(ModelMixin, ConfigMixin):
obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(32, 32, 64)`): Tuple of block output channels.
obj:`(32, 32, 64)`): Tuple of block output channels.
mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet.
act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks.
norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks.
layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block.
downsample_each_block (`int`, *optional*, defaults to False:
experimental feature for using a UNet without upsampling.
"""
"""
@
register_to_config
@
register_to_config
...
@@ -54,16 +75,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -54,16 +75,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
out_channels
:
int
=
2
,
out_channels
:
int
=
2
,
extra_in_channels
:
int
=
0
,
extra_in_channels
:
int
=
0
,
time_embedding_type
:
str
=
"fourier"
,
time_embedding_type
:
str
=
"fourier"
,
freq_shift
:
int
=
0
,
flip_sin_to_cos
:
bool
=
True
,
flip_sin_to_cos
:
bool
=
True
,
use_timestep_embedding
:
bool
=
False
,
use_timestep_embedding
:
bool
=
False
,
freq_shift
:
float
=
0.0
,
down_block_types
:
Tuple
[
str
]
=
(
"DownBlock1DNoSkip"
,
"DownBlock1D"
,
"AttnDownBlock1D"
),
down_block_types
:
Tuple
[
str
]
=
(
"DownBlock1DNoSkip"
,
"DownBlock1D"
,
"AttnDownBlock1D"
),
mid_block_type
:
str
=
"UNetMidBlock1D"
,
up_block_types
:
Tuple
[
str
]
=
(
"AttnUpBlock1D"
,
"UpBlock1D"
,
"UpBlock1DNoSkip"
),
up_block_types
:
Tuple
[
str
]
=
(
"AttnUpBlock1D"
,
"UpBlock1D"
,
"UpBlock1DNoSkip"
),
mid_block_type
:
Tuple
[
str
]
=
"UNetMidBlock1D"
,
out_block_type
:
str
=
None
,
block_out_channels
:
Tuple
[
int
]
=
(
32
,
32
,
64
),
block_out_channels
:
Tuple
[
int
]
=
(
32
,
32
,
64
),
act_fn
:
str
=
None
,
norm_num_groups
:
int
=
8
,
layers_per_block
:
int
=
1
,
downsample_each_block
:
bool
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
sample_size
=
sample_size
self
.
sample_size
=
sample_size
# time
# time
...
@@ -73,12 +98,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -73,12 +98,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
)
)
timestep_input_dim
=
2
*
block_out_channels
[
0
]
timestep_input_dim
=
2
*
block_out_channels
[
0
]
elif
time_embedding_type
==
"positional"
:
elif
time_embedding_type
==
"positional"
:
self
.
time_proj
=
Timesteps
(
block_out_channels
[
0
],
flip_sin_to_cos
,
freq_shift
)
self
.
time_proj
=
Timesteps
(
block_out_channels
[
0
],
flip_sin_to_cos
=
flip_sin_to_cos
,
downscale_freq_shift
=
freq_shift
)
timestep_input_dim
=
block_out_channels
[
0
]
timestep_input_dim
=
block_out_channels
[
0
]
if
use_timestep_embedding
:
if
use_timestep_embedding
:
time_embed_dim
=
block_out_channels
[
0
]
*
4
time_embed_dim
=
block_out_channels
[
0
]
*
4
self
.
time_embedding
=
TimestepEmbedding
(
timestep_input_dim
,
time_embed_dim
)
self
.
time_mlp
=
TimestepEmbedding
(
in_channels
=
timestep_input_dim
,
time_embed_dim
=
time_embed_dim
,
act_fn
=
act_fn
,
out_dim
=
block_out_channels
[
0
],
)
self
.
down_blocks
=
nn
.
ModuleList
([])
self
.
down_blocks
=
nn
.
ModuleList
([])
self
.
mid_block
=
None
self
.
mid_block
=
None
...
@@ -94,38 +126,66 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -94,38 +126,66 @@ class UNet1DModel(ModelMixin, ConfigMixin):
if
i
==
0
:
if
i
==
0
:
input_channel
+=
extra_in_channels
input_channel
+=
extra_in_channels
is_final_block
=
i
==
len
(
block_out_channels
)
-
1
down_block
=
get_down_block
(
down_block
=
get_down_block
(
down_block_type
,
down_block_type
,
num_layers
=
layers_per_block
,
in_channels
=
input_channel
,
in_channels
=
input_channel
,
out_channels
=
output_channel
,
out_channels
=
output_channel
,
temb_channels
=
block_out_channels
[
0
],
add_downsample
=
not
is_final_block
or
downsample_each_block
,
)
)
self
.
down_blocks
.
append
(
down_block
)
self
.
down_blocks
.
append
(
down_block
)
# mid
# mid
self
.
mid_block
=
get_mid_block
(
self
.
mid_block
=
get_mid_block
(
mid_block_type
=
mid_block_type
,
mid_block_type
,
mid_channels
=
block_out_channels
[
-
1
],
in_channels
=
block_out_channels
[
-
1
],
in_channels
=
block_out_channels
[
-
1
],
out_channels
=
None
,
mid_channels
=
block_out_channels
[
-
1
],
out_channels
=
block_out_channels
[
-
1
],
embed_dim
=
block_out_channels
[
0
],
num_layers
=
layers_per_block
,
add_downsample
=
downsample_each_block
,
)
)
# up
# up
reversed_block_out_channels
=
list
(
reversed
(
block_out_channels
))
reversed_block_out_channels
=
list
(
reversed
(
block_out_channels
))
output_channel
=
reversed_block_out_channels
[
0
]
output_channel
=
reversed_block_out_channels
[
0
]
if
out_block_type
is
None
:
final_upsample_channels
=
out_channels
else
:
final_upsample_channels
=
block_out_channels
[
0
]
for
i
,
up_block_type
in
enumerate
(
up_block_types
):
for
i
,
up_block_type
in
enumerate
(
up_block_types
):
prev_output_channel
=
output_channel
prev_output_channel
=
output_channel
output_channel
=
reversed_block_out_channels
[
i
+
1
]
if
i
<
len
(
up_block_types
)
-
1
else
out_channels
output_channel
=
(
reversed_block_out_channels
[
i
+
1
]
if
i
<
len
(
up_block_types
)
-
1
else
final_upsample_channels
)
is_final_block
=
i
==
len
(
block_out_channels
)
-
1
up_block
=
get_up_block
(
up_block
=
get_up_block
(
up_block_type
,
up_block_type
,
num_layers
=
layers_per_block
,
in_channels
=
prev_output_channel
,
in_channels
=
prev_output_channel
,
out_channels
=
output_channel
,
out_channels
=
output_channel
,
temb_channels
=
block_out_channels
[
0
],
add_upsample
=
not
is_final_block
,
)
)
self
.
up_blocks
.
append
(
up_block
)
self
.
up_blocks
.
append
(
up_block
)
prev_output_channel
=
output_channel
prev_output_channel
=
output_channel
# TODO(PVP, Nathan) placeholder for RL application to be merged shortly
# out
# Totally fine to add another layer with a if statement - no need for nn.Identity here
num_groups_out
=
norm_num_groups
if
norm_num_groups
is
not
None
else
min
(
block_out_channels
[
0
]
//
4
,
32
)
self
.
out_block
=
get_out_block
(
out_block_type
=
out_block_type
,
num_groups_out
=
num_groups_out
,
embed_dim
=
block_out_channels
[
0
],
out_channels
=
out_channels
,
act_fn
=
act_fn
,
fc_dim
=
block_out_channels
[
-
1
]
//
4
,
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -144,11 +204,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -144,11 +204,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
[`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
[`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
"""
# 1. time
if
len
(
timestep
.
shape
)
==
0
:
timestep
=
timestep
[
None
]
timestep_embed
=
self
.
time_proj
(
timestep
)[...,
None
]
# 1. time
timesteps
=
timestep
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
elif
torch
.
is_tensor
(
timesteps
)
and
len
(
timesteps
.
shape
)
==
0
:
timesteps
=
timesteps
[
None
].
to
(
sample
.
device
)
timestep_embed
=
self
.
time_proj
(
timesteps
)
if
self
.
config
.
use_timestep_embedding
:
timestep_embed
=
self
.
time_mlp
(
timestep_embed
)
else
:
timestep_embed
=
timestep_embed
[...,
None
]
timestep_embed
=
timestep_embed
.
repeat
([
1
,
1
,
sample
.
shape
[
2
]]).
to
(
sample
.
dtype
)
timestep_embed
=
timestep_embed
.
repeat
([
1
,
1
,
sample
.
shape
[
2
]]).
to
(
sample
.
dtype
)
# 2. down
# 2. down
...
@@ -158,13 +226,18 @@ class UNet1DModel(ModelMixin, ConfigMixin):
...
@@ -158,13 +226,18 @@ class UNet1DModel(ModelMixin, ConfigMixin):
down_block_res_samples
+=
res_samples
down_block_res_samples
+=
res_samples
# 3. mid
# 3. mid
sample
=
self
.
mid_block
(
sample
)
if
self
.
mid_block
:
sample
=
self
.
mid_block
(
sample
,
timestep_embed
)
# 4. up
# 4. up
for
i
,
upsample_block
in
enumerate
(
self
.
up_blocks
):
for
i
,
upsample_block
in
enumerate
(
self
.
up_blocks
):
res_samples
=
down_block_res_samples
[
-
1
:]
res_samples
=
down_block_res_samples
[
-
1
:]
down_block_res_samples
=
down_block_res_samples
[:
-
1
]
down_block_res_samples
=
down_block_res_samples
[:
-
1
]
sample
=
upsample_block
(
sample
,
res_samples
)
sample
=
upsample_block
(
sample
,
res_hidden_states_tuple
=
res_samples
,
temb
=
timestep_embed
)
# 5. post-process
if
self
.
out_block
:
sample
=
self
.
out_block
(
sample
,
timestep_embed
)
if
not
return_dict
:
if
not
return_dict
:
return
(
sample
,)
return
(
sample
,)
...
...
src/diffusers/models/unet_1d_blocks.py
View file @
554b374d
...
@@ -17,6 +17,256 @@ import torch
...
@@ -17,6 +17,256 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch
import
nn
from
.resnet
import
Downsample1D
,
ResidualTemporalBlock1D
,
Upsample1D
,
rearrange_dims
class
DownResnetBlock1D
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
=
None
,
num_layers
=
1
,
conv_shortcut
=
False
,
temb_channels
=
32
,
groups
=
32
,
groups_out
=
None
,
non_linearity
=
None
,
time_embedding_norm
=
"default"
,
output_scale_factor
=
1.0
,
add_downsample
=
True
,
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
time_embedding_norm
=
time_embedding_norm
self
.
add_downsample
=
add_downsample
self
.
output_scale_factor
=
output_scale_factor
if
groups_out
is
None
:
groups_out
=
groups
# there will always be at least one resnet
resnets
=
[
ResidualTemporalBlock1D
(
in_channels
,
out_channels
,
embed_dim
=
temb_channels
)]
for
_
in
range
(
num_layers
):
resnets
.
append
(
ResidualTemporalBlock1D
(
out_channels
,
out_channels
,
embed_dim
=
temb_channels
))
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
if
non_linearity
==
"swish"
:
self
.
nonlinearity
=
lambda
x
:
F
.
silu
(
x
)
elif
non_linearity
==
"mish"
:
self
.
nonlinearity
=
nn
.
Mish
()
elif
non_linearity
==
"silu"
:
self
.
nonlinearity
=
nn
.
SiLU
()
else
:
self
.
nonlinearity
=
None
self
.
downsample
=
None
if
add_downsample
:
self
.
downsample
=
Downsample1D
(
out_channels
,
use_conv
=
True
,
padding
=
1
)
def
forward
(
self
,
hidden_states
,
temb
=
None
):
output_states
=
()
hidden_states
=
self
.
resnets
[
0
](
hidden_states
,
temb
)
for
resnet
in
self
.
resnets
[
1
:]:
hidden_states
=
resnet
(
hidden_states
,
temb
)
output_states
+=
(
hidden_states
,)
if
self
.
nonlinearity
is
not
None
:
hidden_states
=
self
.
nonlinearity
(
hidden_states
)
if
self
.
downsample
is
not
None
:
hidden_states
=
self
.
downsample
(
hidden_states
)
return
hidden_states
,
output_states
class
UpResnetBlock1D
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
=
None
,
num_layers
=
1
,
temb_channels
=
32
,
groups
=
32
,
groups_out
=
None
,
non_linearity
=
None
,
time_embedding_norm
=
"default"
,
output_scale_factor
=
1.0
,
add_upsample
=
True
,
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
time_embedding_norm
=
time_embedding_norm
self
.
add_upsample
=
add_upsample
self
.
output_scale_factor
=
output_scale_factor
if
groups_out
is
None
:
groups_out
=
groups
# there will always be at least one resnet
resnets
=
[
ResidualTemporalBlock1D
(
2
*
in_channels
,
out_channels
,
embed_dim
=
temb_channels
)]
for
_
in
range
(
num_layers
):
resnets
.
append
(
ResidualTemporalBlock1D
(
out_channels
,
out_channels
,
embed_dim
=
temb_channels
))
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
if
non_linearity
==
"swish"
:
self
.
nonlinearity
=
lambda
x
:
F
.
silu
(
x
)
elif
non_linearity
==
"mish"
:
self
.
nonlinearity
=
nn
.
Mish
()
elif
non_linearity
==
"silu"
:
self
.
nonlinearity
=
nn
.
SiLU
()
else
:
self
.
nonlinearity
=
None
self
.
upsample
=
None
if
add_upsample
:
self
.
upsample
=
Upsample1D
(
out_channels
,
use_conv_transpose
=
True
)
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
=
None
,
temb
=
None
):
if
res_hidden_states_tuple
is
not
None
:
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
hidden_states
=
torch
.
cat
((
hidden_states
,
res_hidden_states
),
dim
=
1
)
hidden_states
=
self
.
resnets
[
0
](
hidden_states
,
temb
)
for
resnet
in
self
.
resnets
[
1
:]:
hidden_states
=
resnet
(
hidden_states
,
temb
)
if
self
.
nonlinearity
is
not
None
:
hidden_states
=
self
.
nonlinearity
(
hidden_states
)
if
self
.
upsample
is
not
None
:
hidden_states
=
self
.
upsample
(
hidden_states
)
return
hidden_states
class
ValueFunctionMidBlock1D
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
embed_dim
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
embed_dim
=
embed_dim
self
.
res1
=
ResidualTemporalBlock1D
(
in_channels
,
in_channels
//
2
,
embed_dim
=
embed_dim
)
self
.
down1
=
Downsample1D
(
out_channels
//
2
,
use_conv
=
True
)
self
.
res2
=
ResidualTemporalBlock1D
(
in_channels
//
2
,
in_channels
//
4
,
embed_dim
=
embed_dim
)
self
.
down2
=
Downsample1D
(
out_channels
//
4
,
use_conv
=
True
)
def
forward
(
self
,
x
,
temb
=
None
):
x
=
self
.
res1
(
x
,
temb
)
x
=
self
.
down1
(
x
)
x
=
self
.
res2
(
x
,
temb
)
x
=
self
.
down2
(
x
)
return
x
class
MidResTemporalBlock1D
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
embed_dim
,
num_layers
:
int
=
1
,
add_downsample
:
bool
=
False
,
add_upsample
:
bool
=
False
,
non_linearity
=
None
,
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
add_downsample
=
add_downsample
# there will always be at least one resnet
resnets
=
[
ResidualTemporalBlock1D
(
in_channels
,
out_channels
,
embed_dim
=
embed_dim
)]
for
_
in
range
(
num_layers
):
resnets
.
append
(
ResidualTemporalBlock1D
(
out_channels
,
out_channels
,
embed_dim
=
embed_dim
))
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
if
non_linearity
==
"swish"
:
self
.
nonlinearity
=
lambda
x
:
F
.
silu
(
x
)
elif
non_linearity
==
"mish"
:
self
.
nonlinearity
=
nn
.
Mish
()
elif
non_linearity
==
"silu"
:
self
.
nonlinearity
=
nn
.
SiLU
()
else
:
self
.
nonlinearity
=
None
self
.
upsample
=
None
if
add_upsample
:
self
.
upsample
=
Downsample1D
(
out_channels
,
use_conv
=
True
)
self
.
downsample
=
None
if
add_downsample
:
self
.
downsample
=
Downsample1D
(
out_channels
,
use_conv
=
True
)
if
self
.
upsample
and
self
.
downsample
:
raise
ValueError
(
"Block cannot downsample and upsample"
)
def
forward
(
self
,
hidden_states
,
temb
):
hidden_states
=
self
.
resnets
[
0
](
hidden_states
,
temb
)
for
resnet
in
self
.
resnets
[
1
:]:
hidden_states
=
resnet
(
hidden_states
,
temb
)
if
self
.
upsample
:
hidden_states
=
self
.
upsample
(
hidden_states
)
if
self
.
downsample
:
self
.
downsample
=
self
.
downsample
(
hidden_states
)
return
hidden_states
class
OutConv1DBlock
(
nn
.
Module
):
def
__init__
(
self
,
num_groups_out
,
out_channels
,
embed_dim
,
act_fn
):
super
().
__init__
()
self
.
final_conv1d_1
=
nn
.
Conv1d
(
embed_dim
,
embed_dim
,
5
,
padding
=
2
)
self
.
final_conv1d_gn
=
nn
.
GroupNorm
(
num_groups_out
,
embed_dim
)
if
act_fn
==
"silu"
:
self
.
final_conv1d_act
=
nn
.
SiLU
()
if
act_fn
==
"mish"
:
self
.
final_conv1d_act
=
nn
.
Mish
()
self
.
final_conv1d_2
=
nn
.
Conv1d
(
embed_dim
,
out_channels
,
1
)
def
forward
(
self
,
hidden_states
,
temb
=
None
):
hidden_states
=
self
.
final_conv1d_1
(
hidden_states
)
hidden_states
=
rearrange_dims
(
hidden_states
)
hidden_states
=
self
.
final_conv1d_gn
(
hidden_states
)
hidden_states
=
rearrange_dims
(
hidden_states
)
hidden_states
=
self
.
final_conv1d_act
(
hidden_states
)
hidden_states
=
self
.
final_conv1d_2
(
hidden_states
)
return
hidden_states
class
OutValueFunctionBlock
(
nn
.
Module
):
def
__init__
(
self
,
fc_dim
,
embed_dim
):
super
().
__init__
()
self
.
final_block
=
nn
.
ModuleList
(
[
nn
.
Linear
(
fc_dim
+
embed_dim
,
fc_dim
//
2
),
nn
.
Mish
(),
nn
.
Linear
(
fc_dim
//
2
,
1
),
]
)
def
forward
(
self
,
hidden_states
,
temb
):
hidden_states
=
hidden_states
.
view
(
hidden_states
.
shape
[
0
],
-
1
)
hidden_states
=
torch
.
cat
((
hidden_states
,
temb
),
dim
=-
1
)
for
layer
in
self
.
final_block
:
hidden_states
=
layer
(
hidden_states
)
return
hidden_states
_kernels
=
{
_kernels
=
{
"linear"
:
[
1
/
8
,
3
/
8
,
3
/
8
,
1
/
8
],
"linear"
:
[
1
/
8
,
3
/
8
,
3
/
8
,
1
/
8
],
...
@@ -62,7 +312,7 @@ class Upsample1d(nn.Module):
...
@@ -62,7 +312,7 @@ class Upsample1d(nn.Module):
self
.
pad
=
kernel_1d
.
shape
[
0
]
//
2
-
1
self
.
pad
=
kernel_1d
.
shape
[
0
]
//
2
-
1
self
.
register_buffer
(
"kernel"
,
kernel_1d
)
self
.
register_buffer
(
"kernel"
,
kernel_1d
)
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
,
temb
=
None
):
hidden_states
=
F
.
pad
(
hidden_states
,
((
self
.
pad
+
1
)
//
2
,)
*
2
,
self
.
pad_mode
)
hidden_states
=
F
.
pad
(
hidden_states
,
((
self
.
pad
+
1
)
//
2
,)
*
2
,
self
.
pad_mode
)
weight
=
hidden_states
.
new_zeros
([
hidden_states
.
shape
[
1
],
hidden_states
.
shape
[
1
],
self
.
kernel
.
shape
[
0
]])
weight
=
hidden_states
.
new_zeros
([
hidden_states
.
shape
[
1
],
hidden_states
.
shape
[
1
],
self
.
kernel
.
shape
[
0
]])
indices
=
torch
.
arange
(
hidden_states
.
shape
[
1
],
device
=
hidden_states
.
device
)
indices
=
torch
.
arange
(
hidden_states
.
shape
[
1
],
device
=
hidden_states
.
device
)
...
@@ -162,32 +412,6 @@ class ResConvBlock(nn.Module):
...
@@ -162,32 +412,6 @@ class ResConvBlock(nn.Module):
return
output
return
output
def
get_down_block
(
down_block_type
,
out_channels
,
in_channels
):
if
down_block_type
==
"DownBlock1D"
:
return
DownBlock1D
(
out_channels
=
out_channels
,
in_channels
=
in_channels
)
elif
down_block_type
==
"AttnDownBlock1D"
:
return
AttnDownBlock1D
(
out_channels
=
out_channels
,
in_channels
=
in_channels
)
elif
down_block_type
==
"DownBlock1DNoSkip"
:
return
DownBlock1DNoSkip
(
out_channels
=
out_channels
,
in_channels
=
in_channels
)
raise
ValueError
(
f
"
{
down_block_type
}
does not exist."
)
def
get_up_block
(
up_block_type
,
in_channels
,
out_channels
):
if
up_block_type
==
"UpBlock1D"
:
return
UpBlock1D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
)
elif
up_block_type
==
"AttnUpBlock1D"
:
return
AttnUpBlock1D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
)
elif
up_block_type
==
"UpBlock1DNoSkip"
:
return
UpBlock1DNoSkip
(
in_channels
=
in_channels
,
out_channels
=
out_channels
)
raise
ValueError
(
f
"
{
up_block_type
}
does not exist."
)
def
get_mid_block
(
mid_block_type
,
in_channels
,
mid_channels
,
out_channels
):
if
mid_block_type
==
"UNetMidBlock1D"
:
return
UNetMidBlock1D
(
in_channels
=
in_channels
,
mid_channels
=
mid_channels
,
out_channels
=
out_channels
)
raise
ValueError
(
f
"
{
mid_block_type
}
does not exist."
)
class
UNetMidBlock1D
(
nn
.
Module
):
class
UNetMidBlock1D
(
nn
.
Module
):
def
__init__
(
self
,
mid_channels
,
in_channels
,
out_channels
=
None
):
def
__init__
(
self
,
mid_channels
,
in_channels
,
out_channels
=
None
):
super
().
__init__
()
super
().
__init__
()
...
@@ -217,7 +441,7 @@ class UNetMidBlock1D(nn.Module):
...
@@ -217,7 +441,7 @@ class UNetMidBlock1D(nn.Module):
self
.
attentions
=
nn
.
ModuleList
(
attentions
)
self
.
attentions
=
nn
.
ModuleList
(
attentions
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
,
temb
=
None
):
hidden_states
=
self
.
down
(
hidden_states
)
hidden_states
=
self
.
down
(
hidden_states
)
for
attn
,
resnet
in
zip
(
self
.
attentions
,
self
.
resnets
):
for
attn
,
resnet
in
zip
(
self
.
attentions
,
self
.
resnets
):
hidden_states
=
resnet
(
hidden_states
)
hidden_states
=
resnet
(
hidden_states
)
...
@@ -322,7 +546,7 @@ class AttnUpBlock1D(nn.Module):
...
@@ -322,7 +546,7 @@ class AttnUpBlock1D(nn.Module):
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
up
=
Upsample1d
(
kernel
=
"cubic"
)
self
.
up
=
Upsample1d
(
kernel
=
"cubic"
)
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
):
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
,
temb
=
None
):
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
...
@@ -349,7 +573,7 @@ class UpBlock1D(nn.Module):
...
@@ -349,7 +573,7 @@ class UpBlock1D(nn.Module):
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
up
=
Upsample1d
(
kernel
=
"cubic"
)
self
.
up
=
Upsample1d
(
kernel
=
"cubic"
)
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
):
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
,
temb
=
None
):
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
...
@@ -374,7 +598,7 @@ class UpBlock1DNoSkip(nn.Module):
...
@@ -374,7 +598,7 @@ class UpBlock1DNoSkip(nn.Module):
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
self
.
resnets
=
nn
.
ModuleList
(
resnets
)
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
):
def
forward
(
self
,
hidden_states
,
res_hidden_states_tuple
,
temb
=
None
):
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
res_hidden_states
=
res_hidden_states_tuple
[
-
1
]
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
hidden_states
=
torch
.
cat
([
hidden_states
,
res_hidden_states
],
dim
=
1
)
...
@@ -382,3 +606,63 @@ class UpBlock1DNoSkip(nn.Module):
...
@@ -382,3 +606,63 @@ class UpBlock1DNoSkip(nn.Module):
hidden_states
=
resnet
(
hidden_states
)
hidden_states
=
resnet
(
hidden_states
)
return
hidden_states
return
hidden_states
def
get_down_block
(
down_block_type
,
num_layers
,
in_channels
,
out_channels
,
temb_channels
,
add_downsample
):
if
down_block_type
==
"DownResnetBlock1D"
:
return
DownResnetBlock1D
(
in_channels
=
in_channels
,
num_layers
=
num_layers
,
out_channels
=
out_channels
,
temb_channels
=
temb_channels
,
add_downsample
=
add_downsample
,
)
elif
down_block_type
==
"DownBlock1D"
:
return
DownBlock1D
(
out_channels
=
out_channels
,
in_channels
=
in_channels
)
elif
down_block_type
==
"AttnDownBlock1D"
:
return
AttnDownBlock1D
(
out_channels
=
out_channels
,
in_channels
=
in_channels
)
elif
down_block_type
==
"DownBlock1DNoSkip"
:
return
DownBlock1DNoSkip
(
out_channels
=
out_channels
,
in_channels
=
in_channels
)
raise
ValueError
(
f
"
{
down_block_type
}
does not exist."
)
def
get_up_block
(
up_block_type
,
num_layers
,
in_channels
,
out_channels
,
temb_channels
,
add_upsample
):
if
up_block_type
==
"UpResnetBlock1D"
:
return
UpResnetBlock1D
(
in_channels
=
in_channels
,
num_layers
=
num_layers
,
out_channels
=
out_channels
,
temb_channels
=
temb_channels
,
add_upsample
=
add_upsample
,
)
elif
up_block_type
==
"UpBlock1D"
:
return
UpBlock1D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
)
elif
up_block_type
==
"AttnUpBlock1D"
:
return
AttnUpBlock1D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
)
elif
up_block_type
==
"UpBlock1DNoSkip"
:
return
UpBlock1DNoSkip
(
in_channels
=
in_channels
,
out_channels
=
out_channels
)
raise
ValueError
(
f
"
{
up_block_type
}
does not exist."
)
def
get_mid_block
(
mid_block_type
,
num_layers
,
in_channels
,
mid_channels
,
out_channels
,
embed_dim
,
add_downsample
):
if
mid_block_type
==
"MidResTemporalBlock1D"
:
return
MidResTemporalBlock1D
(
num_layers
=
num_layers
,
in_channels
=
in_channels
,
out_channels
=
out_channels
,
embed_dim
=
embed_dim
,
add_downsample
=
add_downsample
,
)
elif
mid_block_type
==
"ValueFunctionMidBlock1D"
:
return
ValueFunctionMidBlock1D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
embed_dim
=
embed_dim
)
elif
mid_block_type
==
"UNetMidBlock1D"
:
return
UNetMidBlock1D
(
in_channels
=
in_channels
,
mid_channels
=
mid_channels
,
out_channels
=
out_channels
)
raise
ValueError
(
f
"
{
mid_block_type
}
does not exist."
)
def
get_out_block
(
*
,
out_block_type
,
num_groups_out
,
embed_dim
,
out_channels
,
act_fn
,
fc_dim
):
if
out_block_type
==
"OutConv1DBlock"
:
return
OutConv1DBlock
(
num_groups_out
,
out_channels
,
embed_dim
,
act_fn
)
elif
out_block_type
==
"ValueFunction"
:
return
OutValueFunctionBlock
(
fc_dim
,
embed_dim
)
return
None
src/diffusers/models/unet_2d.py
View file @
554b374d
...
@@ -51,7 +51,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
...
@@ -51,7 +51,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to :
flip_sin_to_cos (`bool`, *optional*, defaults to :
obj:`
Fals
e`): Whether to flip sin to cos for fourier time embedding.
obj:`
Tru
e`): Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to :
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
types.
types.
...
...
src/diffusers/models/unet_2d_condition.py
View file @
554b374d
...
@@ -60,7 +60,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
...
@@ -60,7 +60,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `
Fals
e`):
flip_sin_to_cos (`bool`, *optional*, defaults to `
Tru
e`):
Whether to flip the sin to cos in the time embedding.
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
...
...
src/diffusers/pipeline_flax_utils.py
View file @
554b374d
...
@@ -47,7 +47,7 @@ logger = logging.get_logger(__name__)
...
@@ -47,7 +47,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES
=
{
LOADABLE_CLASSES
=
{
"diffusers"
:
{
"diffusers"
:
{
"FlaxModelMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FlaxModelMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FlaxSchedulerMixin"
:
[
"save_
config"
,
"from_config
"
],
"FlaxSchedulerMixin"
:
[
"save_
pretrained"
,
"from_pretrained
"
],
"FlaxDiffusionPipeline"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FlaxDiffusionPipeline"
:
[
"save_pretrained"
,
"from_pretrained"
],
},
},
"transformers"
:
{
"transformers"
:
{
...
@@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
>>> from diffusers import FlaxDPMSolverMultistepScheduler
>>> from diffusers import FlaxDPMSolverMultistepScheduler
>>> model_id = "runwayml/stable-diffusion-v1-5"
>>> model_id = "runwayml/stable-diffusion-v1-5"
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_
config
(
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_
pretrained
(
... model_id,
... model_id,
... subfolder="scheduler",
... subfolder="scheduler",
... )
... )
...
@@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
# use snapshot download here to get it working from from_pretrained
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
config_dict
=
cls
.
get
_config
_dict
(
config_dict
=
cls
.
load
_config
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
resume_download
=
resume_download
,
resume_download
=
resume_download
,
...
@@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
else
:
else
:
cached_folder
=
pretrained_model_name_or_path
cached_folder
=
pretrained_model_name_or_path
config_dict
=
cls
.
get
_config
_dict
(
cached_folder
)
config_dict
=
cls
.
load
_config
(
cached_folder
)
# 2. Load the pipeline class, if using custom module then load it from the hub
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
# if we load from explicit class, let's use it
...
@@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
expected_modules
=
set
(
inspect
.
signature
(
pipeline_class
.
__init__
).
parameters
.
keys
())
expected_modules
=
set
(
inspect
.
signature
(
pipeline_class
.
__init__
).
parameters
.
keys
())
passed_class_obj
=
{
k
:
kwargs
.
pop
(
k
)
for
k
in
expected_modules
if
k
in
kwargs
}
passed_class_obj
=
{
k
:
kwargs
.
pop
(
k
)
for
k
in
expected_modules
if
k
in
kwargs
}
init_dict
,
_
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
init_dict
,
_
,
_
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
init_kwargs
=
{}
init_kwargs
=
{}
...
...
src/diffusers/pipeline_utils.py
View file @
554b374d
...
@@ -65,7 +65,7 @@ logger = logging.get_logger(__name__)
...
@@ -65,7 +65,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES
=
{
LOADABLE_CLASSES
=
{
"diffusers"
:
{
"diffusers"
:
{
"ModelMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"ModelMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"SchedulerMixin"
:
[
"save_
config"
,
"from_config
"
],
"SchedulerMixin"
:
[
"save_
pretrained"
,
"from_pretrained
"
],
"DiffusionPipeline"
:
[
"save_pretrained"
,
"from_pretrained"
],
"DiffusionPipeline"
:
[
"save_pretrained"
,
"from_pretrained"
],
"OnnxRuntimeModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"OnnxRuntimeModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
},
},
...
@@ -207,7 +207,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -207,7 +207,7 @@ class DiffusionPipeline(ConfigMixin):
if
torch_device
is
None
:
if
torch_device
is
None
:
return
self
return
self
module_names
,
_
=
self
.
extract_init_dict
(
dict
(
self
.
config
))
module_names
,
_
,
_
=
self
.
extract_init_dict
(
dict
(
self
.
config
))
for
name
in
module_names
.
keys
():
for
name
in
module_names
.
keys
():
module
=
getattr
(
self
,
name
)
module
=
getattr
(
self
,
name
)
if
isinstance
(
module
,
torch
.
nn
.
Module
):
if
isinstance
(
module
,
torch
.
nn
.
Module
):
...
@@ -228,7 +228,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -228,7 +228,7 @@ class DiffusionPipeline(ConfigMixin):
Returns:
Returns:
`torch.device`: The torch device on which the pipeline is located.
`torch.device`: The torch device on which the pipeline is located.
"""
"""
module_names
,
_
=
self
.
extract_init_dict
(
dict
(
self
.
config
))
module_names
,
_
,
_
=
self
.
extract_init_dict
(
dict
(
self
.
config
))
for
name
in
module_names
.
keys
():
for
name
in
module_names
.
keys
():
module
=
getattr
(
self
,
name
)
module
=
getattr
(
self
,
name
)
if
isinstance
(
module
,
torch
.
nn
.
Module
):
if
isinstance
(
module
,
torch
.
nn
.
Module
):
...
@@ -377,11 +377,11 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -377,11 +377,11 @@ class DiffusionPipeline(ConfigMixin):
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> #
Download pipeline, but overwrite
scheduler
>>> #
Use a different
scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler.from_config(
"runwayml/stable-diffusion-v1-5", subfolder="
scheduler
"
)
>>> scheduler = LMSDiscreteScheduler.from_config(
pipeline.
scheduler
.config
)
>>> pipeline
= DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
scheduler
=
scheduler
)
>>> pipeline
.
scheduler
=
scheduler
```
```
"""
"""
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
...
@@ -428,7 +428,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -428,7 +428,7 @@ class DiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
# use snapshot download here to get it working from from_pretrained
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
config_dict
=
cls
.
get
_config
_dict
(
config_dict
=
cls
.
load
_config
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
resume_download
=
resume_download
,
resume_download
=
resume_download
,
...
@@ -474,7 +474,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -474,7 +474,7 @@ class DiffusionPipeline(ConfigMixin):
else
:
else
:
cached_folder
=
pretrained_model_name_or_path
cached_folder
=
pretrained_model_name_or_path
config_dict
=
cls
.
get
_config
_dict
(
cached_folder
)
config_dict
=
cls
.
load
_config
(
cached_folder
)
# 2. Load the pipeline class, if using custom module then load it from the hub
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
# if we load from explicit class, let's use it
...
@@ -513,7 +513,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -513,7 +513,7 @@ class DiffusionPipeline(ConfigMixin):
expected_modules
=
set
(
inspect
.
signature
(
pipeline_class
.
__init__
).
parameters
.
keys
())
-
set
([
"self"
])
expected_modules
=
set
(
inspect
.
signature
(
pipeline_class
.
__init__
).
parameters
.
keys
())
-
set
([
"self"
])
passed_class_obj
=
{
k
:
kwargs
.
pop
(
k
)
for
k
in
expected_modules
if
k
in
kwargs
}
passed_class_obj
=
{
k
:
kwargs
.
pop
(
k
)
for
k
in
expected_modules
if
k
in
kwargs
}
init_dict
,
unused_kwargs
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
init_dict
,
unused_kwargs
,
_
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
if
len
(
unused_kwargs
)
>
0
:
if
len
(
unused_kwargs
)
>
0
:
logger
.
warning
(
f
"Keyword arguments
{
unused_kwargs
}
not recognized."
)
logger
.
warning
(
f
"Keyword arguments
{
unused_kwargs
}
not recognized."
)
...
...
src/diffusers/pipelines/README.md
View file @
554b374d
...
@@ -40,7 +40,7 @@ available a colab notebook to directly try them out.
...
@@ -40,7 +40,7 @@ available a colab notebook to directly try them out.
|
[
pndm
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm
)
|
[
**Pseudo Numerical Methods for Diffusion Models on Manifolds**
](
https://arxiv.org/abs/2202.09778
)
|
*Unconditional Image Generation*
|
|
[
pndm
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm
)
|
[
**Pseudo Numerical Methods for Diffusion Models on Manifolds**
](
https://arxiv.org/abs/2202.09778
)
|
*Unconditional Image Generation*
|
|
[
score_sde_ve
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve
)
|
[
**Score-Based Generative Modeling through Stochastic Differential Equations**
](
https://openreview.net/forum?id=PxTIG12RRHS
)
|
*Unconditional Image Generation*
|
|
[
score_sde_ve
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve
)
|
[
**Score-Based Generative Modeling through Stochastic Differential Equations**
](
https://openreview.net/forum?id=PxTIG12RRHS
)
|
*Unconditional Image Generation*
|
|
[
score_sde_vp
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp
)
|
[
**Score-Based Generative Modeling through Stochastic Differential Equations**
](
https://openreview.net/forum?id=PxTIG12RRHS
)
|
*Unconditional Image Generation*
|
|
[
score_sde_vp
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp
)
|
[
**Score-Based Generative Modeling through Stochastic Differential Equations**
](
https://openreview.net/forum?id=PxTIG12RRHS
)
|
*Unconditional Image Generation*
|
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Text-to-Image Generation*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/
training_example
.ipynb)
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Text-to-Image Generation*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/
stable_diffusion
.ipynb)
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Image-to-Image Text-Guided Generation*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Image-to-Image Text-Guided Generation*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Text-Guided Image Inpainting*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
[
stable_diffusion
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion
)
|
[
**Stable Diffusion**
](
https://stability.ai/blog/stable-diffusion-public-release
)
|
*Text-Guided Image Inpainting*
|
[

](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
[
stochastic_karras_ve
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve
)
|
[
**Elucidating the Design Space of Diffusion-Based Generative Models**
](
https://arxiv.org/abs/2206.00364
)
|
*Unconditional Image Generation*
|
|
[
stochastic_karras_ve
](
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve
)
|
[
**Elucidating the Design Space of Diffusion-Based Generative Models**
](
https://arxiv.org/abs/2206.00364
)
|
*Unconditional Image Generation*
|
...
...
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
554b374d
...
@@ -71,7 +71,7 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -71,7 +71,7 @@ class DDPMPipeline(DiffusionPipeline):
"""
"""
message
=
(
message
=
(
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_
config
(<model_id>, predict_epsilon=True)`."
" DDPMScheduler.from_
pretrained
(<model_id>, predict_epsilon=True)`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
...
...
src/diffusers/pipelines/stable_diffusion/README.md
View file @
554b374d
...
@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
...
@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
# make sure you're logged in with `huggingface-cli login`
from
diffusers
import
StableDiffusionPipeline
,
DDIMScheduler
from
diffusers
import
StableDiffusionPipeline
,
DDIMScheduler
scheduler
=
DDIMScheduler
.
from_
config
(
"CompVis/stable-diffusion-v1-4"
,
subfolder
=
"scheduler"
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
"CompVis/stable-diffusion-v1-4"
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"runwayml/stable-diffusion-v1-5"
,
"runwayml/stable-diffusion-v1-5"
,
...
@@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png")
...
@@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
# make sure you're logged in with `huggingface-cli login`
from
diffusers
import
StableDiffusionPipeline
,
LMSDiscreteScheduler
from
diffusers
import
StableDiffusionPipeline
,
LMSDiscreteScheduler
lms
=
LMSDiscreteScheduler
.
from_
config
(
"CompVis/stable-diffusion-v1-4"
,
subfolder
=
"scheduler"
)
lms
=
LMSDiscreteScheduler
.
from_
pretrained
(
"CompVis/stable-diffusion-v1-4"
,
subfolder
=
"scheduler"
)
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"runwayml/stable-diffusion-v1-5"
,
"runwayml/stable-diffusion-v1-5"
,
...
@@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
...
@@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
# make sure you're logged in with `huggingface-cli login`
model_id_or_path
=
"CompVis/stable-diffusion-v1-4"
model_id_or_path
=
"CompVis/stable-diffusion-v1-4"
scheduler
=
DDIMScheduler
.
from_
config
(
model_id_or_path
,
subfolder
=
"scheduler"
)
scheduler
=
DDIMScheduler
.
from_
pretrained
(
model_id_or_path
,
subfolder
=
"scheduler"
)
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
model_id_or_path
,
scheduler
=
scheduler
).
to
(
"cuda"
)
pipe
=
CycleDiffusionPipeline
.
from_pretrained
(
model_id_or_path
,
scheduler
=
scheduler
).
to
(
"cuda"
)
# let's download an initial image
# let's download an initial image
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
554b374d
...
@@ -23,7 +23,7 @@ import numpy as np
...
@@ -23,7 +23,7 @@ import numpy as np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
BaseOutput
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
BaseOutput
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
For more details, see the original paper: https://arxiv.org/abs/2010.02502
...
@@ -109,14 +109,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -109,14 +109,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatible_classes
=
[
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
"PNDMScheduler"
,
"DDPMScheduler"
,
"LMSDiscreteScheduler"
,
"EulerDiscreteScheduler"
,
"EulerAncestralDiscreteScheduler"
,
"DPMSolverMultistepScheduler"
,
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_ddim_flax.py
View file @
554b374d
...
@@ -23,7 +23,12 @@ import flax
...
@@ -23,7 +23,12 @@ import flax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils_flax
import
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
from
.scheduling_utils_flax
import
(
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
,
)
def
betas_for_alpha_bar
(
num_diffusion_timesteps
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
def
betas_for_alpha_bar
(
num_diffusion_timesteps
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
...
@@ -79,8 +84,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -79,8 +84,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
For more details, see the original paper: https://arxiv.org/abs/2010.02502
...
@@ -105,6 +110,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -105,6 +110,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
stable diffusion.
stable diffusion.
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
return
True
return
True
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
554b374d
...
@@ -22,7 +22,7 @@ import numpy as np
...
@@ -22,7 +22,7 @@ import numpy as np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
,
FrozenDict
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
FrozenDict
,
register_to_config
from
..utils
import
BaseOutput
,
deprecate
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
BaseOutput
,
deprecate
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
For more details, see the original paper: https://arxiv.org/abs/2006.11239
...
@@ -104,14 +104,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -104,14 +104,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatible_classes
=
[
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
"DDIMScheduler"
,
"PNDMScheduler"
,
"LMSDiscreteScheduler"
,
"EulerDiscreteScheduler"
,
"EulerAncestralDiscreteScheduler"
,
"DPMSolverMultistepScheduler"
,
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
@@ -204,6 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -204,6 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# for rl-diffuser https://arxiv.org/abs/2205.09991
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif
variance_type
==
"fixed_small_log"
:
elif
variance_type
==
"fixed_small_log"
:
variance
=
torch
.
log
(
torch
.
clamp
(
variance
,
min
=
1e-20
))
variance
=
torch
.
log
(
torch
.
clamp
(
variance
,
min
=
1e-20
))
variance
=
torch
.
exp
(
0.5
*
variance
)
elif
variance_type
==
"fixed_large"
:
elif
variance_type
==
"fixed_large"
:
variance
=
self
.
betas
[
t
]
variance
=
self
.
betas
[
t
]
elif
variance_type
==
"fixed_large_log"
:
elif
variance_type
==
"fixed_large_log"
:
...
@@ -248,7 +242,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -248,7 +242,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
message
=
(
message
=
(
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_
config
(<model_id>, predict_epsilon=True)`."
" DDPMScheduler.from_
pretrained
(<model_id>, predict_epsilon=True)`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
and
predict_epsilon
!=
self
.
config
.
predict_epsilon
:
if
predict_epsilon
is
not
None
and
predict_epsilon
!=
self
.
config
.
predict_epsilon
:
...
@@ -301,6 +295,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -301,6 +295,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_noise
=
torch
.
randn
(
variance_noise
=
torch
.
randn
(
model_output
.
shape
,
generator
=
generator
,
device
=
device
,
dtype
=
model_output
.
dtype
model_output
.
shape
,
generator
=
generator
,
device
=
device
,
dtype
=
model_output
.
dtype
)
)
if
self
.
variance_type
==
"fixed_small_log"
:
variance
=
self
.
_get_variance
(
t
,
predicted_variance
=
predicted_variance
)
*
variance_noise
else
:
variance
=
(
self
.
_get_variance
(
t
,
predicted_variance
=
predicted_variance
)
**
0.5
)
*
variance_noise
variance
=
(
self
.
_get_variance
(
t
,
predicted_variance
=
predicted_variance
)
**
0.5
)
*
variance_noise
pred_prev_sample
=
pred_prev_sample
+
variance
pred_prev_sample
=
pred_prev_sample
+
variance
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
554b374d
...
@@ -24,7 +24,12 @@ from jax import random
...
@@ -24,7 +24,12 @@ from jax import random
from
..configuration_utils
import
ConfigMixin
,
FrozenDict
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
FrozenDict
,
register_to_config
from
..utils
import
deprecate
from
..utils
import
deprecate
from
.scheduling_utils_flax
import
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
from
.scheduling_utils_flax
import
(
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
,
)
def
betas_for_alpha_bar
(
num_diffusion_timesteps
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
def
betas_for_alpha_bar
(
num_diffusion_timesteps
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
...
@@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
For more details, see the original paper: https://arxiv.org/abs/2006.11239
...
@@ -103,6 +108,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -103,6 +108,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
return
True
return
True
...
@@ -221,7 +228,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -221,7 +228,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
message
=
(
message
=
(
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_
config
(<model_id>, predict_epsilon=True)`."
" DDPMScheduler.from_
pretrained
(<model_id>, predict_epsilon=True)`."
)
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.10.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
and
predict_epsilon
!=
self
.
config
.
predict_epsilon
:
if
predict_epsilon
is
not
None
and
predict_epsilon
!=
self
.
config
.
predict_epsilon
:
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
554b374d
...
@@ -21,6 +21,7 @@ import numpy as np
...
@@ -21,6 +21,7 @@ import numpy as np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from
.scheduling_utils
import
SchedulerMixin
,
SchedulerOutput
from
.scheduling_utils
import
SchedulerMixin
,
SchedulerOutput
...
@@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
Args:
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
num_train_timesteps (`int`): number of diffusion steps used to train the model.
...
@@ -116,14 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -116,14 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatible_classes
=
[
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
"DDIMScheduler"
,
"DDPMScheduler"
,
"PNDMScheduler"
,
"LMSDiscreteScheduler"
,
"EulerDiscreteScheduler"
,
"EulerAncestralDiscreteScheduler"
,
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
View file @
554b374d
...
@@ -23,7 +23,12 @@ import jax
...
@@ -23,7 +23,12 @@ import jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
.scheduling_utils_flax
import
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
from
.scheduling_utils_flax
import
(
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
,
FlaxSchedulerMixin
,
FlaxSchedulerOutput
,
broadcast_to_shape_from_left
,
)
def
betas_for_alpha_bar
(
num_diffusion_timesteps
:
int
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
def
betas_for_alpha_bar
(
num_diffusion_timesteps
:
int
,
max_beta
=
0.999
)
->
jnp
.
ndarray
:
...
@@ -96,8 +101,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -96,8 +101,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`
~Config
Mixin`]
also
provides general loading and saving functionality via the [`
~Config
Mixin.save_
config
`] and
[`
Scheduler
Mixin`] provides general loading and saving functionality via the [`
Scheduler
Mixin.save_
pretrained
`] and
[`~
Config
Mixin.from_
config
`] functions.
[`~
Scheduler
Mixin.from_
pretrained
`] functions.
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
...
@@ -143,6 +148,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -143,6 +148,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
return
True
return
True
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment