Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b8be488f
Commit
b8be488f
authored
Jun 07, 2022
by
patil-suraj
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
ce566621
f9a4532f
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
783 additions
and
5 deletions
+783
-5
models/vision/ddpm/test.png
models/vision/ddpm/test.png
+0
-0
models/vision/glide/modeling_glide.py
models/vision/glide/modeling_glide.py
+59
-0
models/vision/glide/run_dalle2.py
models/vision/glide/run_dalle2.py
+0
-0
models/vision/glide/run_glide.py
models/vision/glide/run_glide.py
+16
-0
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-0
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+1
-0
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+665
-0
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+14
-5
src/diffusers/schedulers/gaussian_ddpm.py
src/diffusers/schedulers/gaussian_ddpm.py
+27
-0
No files found.
models/vision/ddpm/test.png
deleted
100644 → 0
View file @
ce566621
102 KB
models/vision/glide/modeling_glide.py
View file @
b8be488f
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
diffusers
import
DiffusionPipeline
from
diffusers
import
UNetGLIDEModel
import
tqdm
import
torch
class
GLIDE
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
:
UNetGLIDEModel
,
noise_scheduler
):
super
().
__init__
()
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
def
__call__
(
self
,
generator
=
None
,
torch_device
=
None
):
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
self
.
unet
.
to
(
torch_device
)
# 1. Sample gaussian noise
image
=
self
.
noise_scheduler
.
sample_noise
((
1
,
self
.
unet
.
in_channels
,
self
.
unet
.
resolution
,
self
.
unet
.
resolution
),
device
=
torch_device
,
generator
=
generator
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
len
(
self
.
noise_scheduler
))),
total
=
len
(
self
.
noise_scheduler
)):
# i) define coefficients for time step t
clip_image_coeff
=
1
/
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
clip_noise_coeff
=
torch
.
sqrt
(
1
/
self
.
noise_scheduler
.
get_alpha_prod
(
t
)
-
1
)
image_coeff
=
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
-
1
))
*
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha
(
t
))
/
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
clip_coeff
=
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha_prod
(
t
-
1
))
*
self
.
noise_scheduler
.
get_beta
(
t
)
/
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
# ii) predict noise residual
with
torch
.
no_grad
():
noise_residual
=
self
.
unet
(
image
,
t
)
# iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean
=
clip_image_coeff
*
image
-
clip_noise_coeff
*
noise_residual
pred_mean
=
torch
.
clamp
(
pred_mean
,
-
1
,
1
)
prev_image
=
clip_coeff
*
pred_mean
+
image_coeff
*
image
# iv) sample variance
prev_variance
=
self
.
noise_scheduler
.
sample_variance
(
t
,
prev_image
.
shape
,
device
=
torch_device
,
generator
=
generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image
=
prev_image
+
prev_variance
image
=
sampled_prev_image
return
image
models/vision/glide/run_dalle2.py
deleted
100644 → 0
View file @
ce566621
models/vision/glide/run_glide.py
0 → 100644
View file @
b8be488f
import
torch
from
.modeling_glide
import
GLIDE
from
diffusers
import
UNetGLIDEModel
,
GaussianDDPMScheduler
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
0
)
# 1. Load models
scheduler
=
GaussianDDPMScheduler
.
from_config
(
"fusing/glide-base"
)
model
=
UNetGLIDEModel
.
from_pretrained
(
"fusing/glide-base"
)
pipeline
=
GLIDE
(
model
,
scheduler
)
img
=
pipeline
(
generator
)
print
(
img
)
src/diffusers/__init__.py
View file @
b8be488f
...
@@ -6,5 +6,6 @@ __version__ = "0.0.1"
...
@@ -6,5 +6,6 @@ __version__ = "0.0.1"
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
from
.models.unet
import
UNetModel
from
.models.unet
import
UNetModel
from
.models.unet_glide
import
UNetGLIDEModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.schedulers.gaussian_ddpm
import
GaussianDDPMScheduler
from
.schedulers.gaussian_ddpm
import
GaussianDDPMScheduler
src/diffusers/models/__init__.py
View file @
b8be488f
...
@@ -17,3 +17,4 @@
...
@@ -17,3 +17,4 @@
# limitations under the License.
# limitations under the License.
from
.unet
import
UNetModel
from
.unet
import
UNetModel
from
.unet_glide
import
UNetGLIDEModel
src/diffusers/models/unet_glide.py
0 → 100644
View file @
b8be488f
import
math
from
abc
import
abstractmethod
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
def
convert_module_to_f16
(
l
):
"""
Convert primitive modules to float16.
"""
if
isinstance
(
l
,
(
nn
.
Conv1d
,
nn
.
Conv2d
,
nn
.
Conv3d
)):
l
.
weight
.
data
=
l
.
weight
.
data
.
half
()
if
l
.
bias
is
not
None
:
l
.
bias
.
data
=
l
.
bias
.
data
.
half
()
def
convert_module_to_f32
(
l
):
"""
Convert primitive modules to float32, undoing convert_module_to_f16().
"""
if
isinstance
(
l
,
(
nn
.
Conv1d
,
nn
.
Conv2d
,
nn
.
Conv3d
)):
l
.
weight
.
data
=
l
.
weight
.
data
.
float
()
if
l
.
bias
is
not
None
:
l
.
bias
.
data
=
l
.
bias
.
data
.
float
()
def
avg_pool_nd
(
dims
,
*
args
,
**
kwargs
):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if
dims
==
1
:
return
nn
.
AvgPool1d
(
*
args
,
**
kwargs
)
elif
dims
==
2
:
return
nn
.
AvgPool2d
(
*
args
,
**
kwargs
)
elif
dims
==
3
:
return
nn
.
AvgPool3d
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
def
conv_nd
(
dims
,
*
args
,
**
kwargs
):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if
dims
==
1
:
return
nn
.
Conv1d
(
*
args
,
**
kwargs
)
elif
dims
==
2
:
return
nn
.
Conv2d
(
*
args
,
**
kwargs
)
elif
dims
==
3
:
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
def
linear
(
*
args
,
**
kwargs
):
"""
Create a linear module.
"""
return
nn
.
Linear
(
*
args
,
**
kwargs
)
class
GroupNorm32
(
nn
.
GroupNorm
):
def
__init__
(
self
,
num_groups
,
num_channels
,
swish
,
eps
=
1e-5
):
super
().
__init__
(
num_groups
=
num_groups
,
num_channels
=
num_channels
,
eps
=
eps
)
self
.
swish
=
swish
def
forward
(
self
,
x
):
y
=
super
().
forward
(
x
.
float
()).
to
(
x
.
dtype
)
if
self
.
swish
==
1.0
:
y
=
F
.
silu
(
y
)
elif
self
.
swish
:
y
=
y
*
F
.
sigmoid
(
y
*
float
(
self
.
swish
))
return
y
def
normalization
(
channels
,
swish
=
0.0
):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
def
timestep_embedding
(
timesteps
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
device
=
timesteps
.
device
)
args
=
timesteps
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
([
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
return
embedding
def
zero_module
(
module
):
"""
Zero out the parameters of a module and return it.
"""
for
p
in
module
.
parameters
():
p
.
detach
().
zero_
()
return
module
class
TimestepBlock
(
nn
.
Module
):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@
abstractmethod
def
forward
(
self
,
x
,
emb
):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class
TimestepEmbedSequential
(
nn
.
Sequential
,
TimestepBlock
):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def
forward
(
self
,
x
,
emb
,
encoder_out
=
None
):
for
layer
in
self
:
if
isinstance
(
layer
,
TimestepBlock
):
x
=
layer
(
x
,
emb
)
elif
isinstance
(
layer
,
AttentionBlock
):
x
=
layer
(
x
,
encoder_out
)
else
:
x
=
layer
(
x
)
return
x
class
Upsample
(
nn
.
Module
):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
dims
=
dims
if
use_conv
:
self
.
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
if
self
.
dims
==
3
:
x
=
F
.
interpolate
(
x
,
(
x
.
shape
[
2
],
x
.
shape
[
3
]
*
2
,
x
.
shape
[
4
]
*
2
),
mode
=
"nearest"
)
else
:
x
=
F
.
interpolate
(
x
,
scale_factor
=
2
,
mode
=
"nearest"
)
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
return
x
class
Downsample
(
nn
.
Module
):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
):
super
().
__init__
()
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
dims
=
dims
stride
=
2
if
dims
!=
3
else
(
1
,
2
,
2
)
if
use_conv
:
self
.
op
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
1
)
else
:
assert
self
.
channels
==
self
.
out_channels
self
.
op
=
avg_pool_nd
(
dims
,
kernel_size
=
stride
,
stride
=
stride
)
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
return
self
.
op
(
x
)
class
ResBlock
(
TimestepBlock
):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def
__init__
(
self
,
channels
,
emb_channels
,
dropout
,
out_channels
=
None
,
use_conv
=
False
,
use_scale_shift_norm
=
False
,
dims
=
2
,
use_checkpoint
=
False
,
up
=
False
,
down
=
False
,
):
super
().
__init__
()
self
.
channels
=
channels
self
.
emb_channels
=
emb_channels
self
.
dropout
=
dropout
self
.
out_channels
=
out_channels
or
channels
self
.
use_conv
=
use_conv
self
.
use_checkpoint
=
use_checkpoint
self
.
use_scale_shift_norm
=
use_scale_shift_norm
self
.
in_layers
=
nn
.
Sequential
(
normalization
(
channels
,
swish
=
1.0
),
nn
.
Identity
(),
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
),
)
self
.
updown
=
up
or
down
if
up
:
self
.
h_upd
=
Upsample
(
channels
,
False
,
dims
)
self
.
x_upd
=
Upsample
(
channels
,
False
,
dims
)
elif
down
:
self
.
h_upd
=
Downsample
(
channels
,
False
,
dims
)
self
.
x_upd
=
Downsample
(
channels
,
False
,
dims
)
else
:
self
.
h_upd
=
self
.
x_upd
=
nn
.
Identity
()
self
.
emb_layers
=
nn
.
Sequential
(
nn
.
SiLU
(),
linear
(
emb_channels
,
2
*
self
.
out_channels
if
use_scale_shift_norm
else
self
.
out_channels
,
),
)
self
.
out_layers
=
nn
.
Sequential
(
normalization
(
self
.
out_channels
,
swish
=
0.0
if
use_scale_shift_norm
else
1.0
),
nn
.
SiLU
()
if
use_scale_shift_norm
else
nn
.
Identity
(),
nn
.
Dropout
(
p
=
dropout
),
zero_module
(
conv_nd
(
dims
,
self
.
out_channels
,
self
.
out_channels
,
3
,
padding
=
1
)),
)
if
self
.
out_channels
==
channels
:
self
.
skip_connection
=
nn
.
Identity
()
elif
use_conv
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
else
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
def
forward
(
self
,
x
,
emb
):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if
self
.
updown
:
in_rest
,
in_conv
=
self
.
in_layers
[:
-
1
],
self
.
in_layers
[
-
1
]
h
=
in_rest
(
x
)
h
=
self
.
h_upd
(
h
)
x
=
self
.
x_upd
(
x
)
h
=
in_conv
(
h
)
else
:
h
=
self
.
in_layers
(
x
)
emb_out
=
self
.
emb_layers
(
emb
).
type
(
h
.
dtype
)
while
len
(
emb_out
.
shape
)
<
len
(
h
.
shape
):
emb_out
=
emb_out
[...,
None
]
if
self
.
use_scale_shift_norm
:
out_norm
,
out_rest
=
self
.
out_layers
[
0
],
self
.
out_layers
[
1
:]
scale
,
shift
=
torch
.
chunk
(
emb_out
,
2
,
dim
=
1
)
h
=
out_norm
(
h
)
*
(
1
+
scale
)
+
shift
h
=
out_rest
(
h
)
else
:
h
=
h
+
emb_out
h
=
self
.
out_layers
(
h
)
return
self
.
skip_connection
(
x
)
+
h
class
AttentionBlock
(
nn
.
Module
):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def
__init__
(
self
,
channels
,
num_heads
=
1
,
num_head_channels
=-
1
,
use_checkpoint
=
False
,
encoder_channels
=
None
,
):
super
().
__init__
()
self
.
channels
=
channels
if
num_head_channels
==
-
1
:
self
.
num_heads
=
num_heads
else
:
assert
(
channels
%
num_head_channels
==
0
),
f
"q,k,v channels
{
channels
}
is not divisible by num_head_channels
{
num_head_channels
}
"
self
.
num_heads
=
channels
//
num_head_channels
self
.
use_checkpoint
=
use_checkpoint
self
.
norm
=
normalization
(
channels
,
swish
=
0.0
)
self
.
qkv
=
conv_nd
(
1
,
channels
,
channels
*
3
,
1
)
self
.
attention
=
QKVAttention
(
self
.
num_heads
)
if
encoder_channels
is
not
None
:
self
.
encoder_kv
=
conv_nd
(
1
,
encoder_channels
,
channels
*
2
,
1
)
self
.
proj_out
=
zero_module
(
conv_nd
(
1
,
channels
,
channels
,
1
))
def
forward
(
self
,
x
,
encoder_out
=
None
):
b
,
c
,
*
spatial
=
x
.
shape
qkv
=
self
.
qkv
(
self
.
norm
(
x
).
view
(
b
,
c
,
-
1
))
if
encoder_out
is
not
None
:
encoder_out
=
self
.
encoder_kv
(
encoder_out
)
h
=
self
.
attention
(
qkv
,
encoder_out
)
else
:
h
=
self
.
attention
(
qkv
)
h
=
self
.
proj_out
(
h
)
return
x
+
h
.
reshape
(
b
,
c
,
*
spatial
)
class
QKVAttention
(
nn
.
Module
):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def
__init__
(
self
,
n_heads
):
super
().
__init__
()
self
.
n_heads
=
n_heads
def
forward
(
self
,
qkv
,
encoder_kv
=
None
):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs
,
width
,
length
=
qkv
.
shape
assert
width
%
(
3
*
self
.
n_heads
)
==
0
ch
=
width
//
(
3
*
self
.
n_heads
)
q
,
k
,
v
=
qkv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
3
,
length
).
split
(
ch
,
dim
=
1
)
if
encoder_kv
is
not
None
:
assert
encoder_kv
.
shape
[
1
]
==
self
.
n_heads
*
ch
*
2
ek
,
ev
=
encoder_kv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
2
,
-
1
).
split
(
ch
,
dim
=
1
)
k
=
torch
.
cat
([
ek
,
k
],
dim
=-
1
)
v
=
torch
.
cat
([
ev
,
v
],
dim
=-
1
)
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
weight
=
torch
.
einsum
(
"bct,bcs->bts"
,
q
*
scale
,
k
*
scale
)
# More stable with f16 than dividing afterwards
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
a
=
torch
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
)
return
a
.
reshape
(
bs
,
-
1
,
length
)
class
UNetGLIDEModel
(
ModelMixin
,
ConfigMixin
):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
"""
def
__init__
(
self
,
in_channels
,
model_channels
,
out_channels
,
num_res_blocks
,
attention_resolutions
,
dropout
=
0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
conv_resample
=
True
,
dims
=
2
,
use_checkpoint
=
False
,
use_fp16
=
False
,
num_heads
=
1
,
num_head_channels
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
encoder_channels
=
None
,
):
super
().
__init__
()
self
.
register
(
in_channels
=
in_channels
,
model_channels
=
model_channels
,
out_channels
=
out_channels
,
num_res_blocks
=
num_res_blocks
,
attention_resolutions
=
attention_resolutions
,
dropout
=
dropout
,
channel_mult
=
channel_mult
,
conv_resample
=
conv_resample
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_fp16
=
use_fp16
,
num_heads
=
num_heads
,
num_head_channels
=
num_head_channels
,
num_heads_upsample
=
num_heads_upsample
,
use_scale_shift_norm
=
use_scale_shift_norm
,
resblock_updown
=
resblock_updown
,
encoder_channels
=
encoder_channels
,
)
if
num_heads_upsample
==
-
1
:
num_heads_upsample
=
num_heads
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
self
.
out_channels
=
out_channels
self
.
num_res_blocks
=
num_res_blocks
self
.
attention_resolutions
=
attention_resolutions
self
.
dropout
=
dropout
self
.
channel_mult
=
channel_mult
self
.
conv_resample
=
conv_resample
self
.
use_checkpoint
=
use_checkpoint
self
.
dtype
=
torch
.
float16
if
use_fp16
else
torch
.
float32
self
.
num_heads
=
num_heads
self
.
num_head_channels
=
num_head_channels
self
.
num_heads_upsample
=
num_heads_upsample
time_embed_dim
=
model_channels
*
4
self
.
time_embed
=
nn
.
Sequential
(
linear
(
model_channels
,
time_embed_dim
),
nn
.
SiLU
(),
linear
(
time_embed_dim
,
time_embed_dim
),
)
ch
=
input_ch
=
int
(
channel_mult
[
0
]
*
model_channels
)
self
.
input_blocks
=
nn
.
ModuleList
([
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
ch
,
3
,
padding
=
1
))])
self
.
_feature_size
=
ch
input_block_chans
=
[
ch
]
ds
=
1
for
level
,
mult
in
enumerate
(
channel_mult
):
for
_
in
range
(
num_res_blocks
):
layers
=
[
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
int
(
mult
*
model_channels
),
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
)
]
ch
=
int
(
mult
*
model_channels
)
if
ds
in
attention_resolutions
:
layers
.
append
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
num_head_channels
,
encoder_channels
=
encoder_channels
,
)
)
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
_feature_size
+=
ch
input_block_chans
.
append
(
ch
)
if
level
!=
len
(
channel_mult
)
-
1
:
out_ch
=
ch
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
down
=
True
,
)
if
resblock_updown
else
Downsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
)
)
ch
=
out_ch
input_block_chans
.
append
(
ch
)
ds
*=
2
self
.
_feature_size
+=
ch
self
.
middle_block
=
TimestepEmbedSequential
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
),
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_head_channels
=
num_head_channels
,
encoder_channels
=
encoder_channels
,
),
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
),
)
self
.
_feature_size
+=
ch
self
.
output_blocks
=
nn
.
ModuleList
([])
for
level
,
mult
in
list
(
enumerate
(
channel_mult
))[::
-
1
]:
for
i
in
range
(
num_res_blocks
+
1
):
ich
=
input_block_chans
.
pop
()
layers
=
[
ResBlock
(
ch
+
ich
,
time_embed_dim
,
dropout
,
out_channels
=
int
(
model_channels
*
mult
),
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
)
]
ch
=
int
(
model_channels
*
mult
)
if
ds
in
attention_resolutions
:
layers
.
append
(
AttentionBlock
(
ch
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads_upsample
,
num_head_channels
=
num_head_channels
,
encoder_channels
=
encoder_channels
,
)
)
if
level
and
i
==
num_res_blocks
:
out_ch
=
ch
layers
.
append
(
ResBlock
(
ch
,
time_embed_dim
,
dropout
,
out_channels
=
out_ch
,
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
up
=
True
,
)
if
resblock_updown
else
Upsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
)
)
ds
//=
2
self
.
output_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
_feature_size
+=
ch
self
.
out
=
nn
.
Sequential
(
normalization
(
ch
,
swish
=
1.0
),
nn
.
Identity
(),
zero_module
(
conv_nd
(
dims
,
input_ch
,
out_channels
,
3
,
padding
=
1
)),
)
self
.
use_fp16
=
use_fp16
def
convert_to_fp16
(
self
):
"""
Convert the torso of the model to float16.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f16
)
self
.
middle_block
.
apply
(
convert_module_to_f16
)
self
.
output_blocks
.
apply
(
convert_module_to_f16
)
def
convert_to_fp32
(
self
):
"""
Convert the torso of the model to float32.
"""
self
.
input_blocks
.
apply
(
convert_module_to_f32
)
self
.
middle_block
.
apply
(
convert_module_to_f32
)
self
.
output_blocks
.
apply
(
convert_module_to_f32
)
def
forward
(
self
,
x
,
timesteps
,
transformer_out
):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert
(
y
is
not
None
)
==
(
self
.
num_classes
is
not
None
),
"must specify y if and only if the model is class-conditional"
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
# project the last token
transformer_proj
=
self
.
transformer_proj
(
transformer_out
[:,
-
1
])
transformer_out
=
transformer_out
.
permute
(
0
,
2
,
1
)
# NLC -> NCL
h
=
x
.
type
(
self
.
dtype
)
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
)
hs
.
append
(
h
)
h
=
self
.
middle_block
(
h
,
emb
)
for
module
in
self
.
output_blocks
:
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
module
(
h
,
emb
)
h
=
h
.
type
(
x
.
dtype
)
return
self
.
out
(
h
)
src/diffusers/pipeline_utils.py
View file @
b8be488f
...
@@ -92,11 +92,17 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -92,11 +92,17 @@ class DiffusionPipeline(ConfigMixin):
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Optional
[
Union
[
str
,
os
.
PathLike
]],
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Optional
[
Union
[
str
,
os
.
PathLike
]],
**
kwargs
):
# use snapshot download here to get it working from from_pretrained
# use snapshot download here to get it working from from_pretrained
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
)
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
_
,
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
)
else
:
cached_folder
=
pretrained_model_name_or_path
module
=
config_dict
.
pop
(
"_module"
,
None
)
config_dict
,
pipeline_kwargs
=
cls
.
get_config_dict
(
cached_folder
)
class_name_
=
config_dict
.
pop
(
"_class_name"
)
module
=
pipeline_kwargs
.
pop
(
"_module"
,
None
)
# TODO(Suraj) - make from hub import work
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
# Add Sylvains code from transformers
init_kwargs
=
{}
init_kwargs
=
{}
...
@@ -118,7 +124,10 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -118,7 +124,10 @@ class DiffusionPipeline(ConfigMixin):
load_method
=
getattr
(
class_obj
,
load_method_name
)
load_method
=
getattr
(
class_obj
,
load_method_name
)
loaded_sub_model
=
load_method
(
os
.
path
.
join
(
cached_folder
,
name
))
if
os
.
path
.
isdir
(
os
.
path
.
join
(
cached_folder
,
name
)):
loaded_sub_model
=
load_method
(
os
.
path
.
join
(
cached_folder
,
name
))
else
:
loaded_sub_model
=
load_method
(
cached_folder
)
init_kwargs
[
name
]
=
loaded_sub_model
# UNet(...), # DiffusionSchedule(...)
init_kwargs
[
name
]
=
loaded_sub_model
# UNet(...), # DiffusionSchedule(...)
...
...
src/diffusers/schedulers/gaussian_ddpm.py
View file @
b8be488f
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
torch
import
torch
import
math
from
torch
import
nn
from
torch
import
nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
...
@@ -24,6 +25,26 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
...
@@ -24,6 +25,26 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
return
torch
.
linspace
(
beta_start
,
beta_end
,
timesteps
,
dtype
=
torch
.
float64
)
return
torch
.
linspace
(
beta_start
,
beta_end
,
timesteps
,
dtype
=
torch
.
float64
)
def
betas_for_alpha_bar
(
num_diffusion_timesteps
,
alpha_bar
,
max_beta
=
0.999
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas
=
[]
for
i
in
range
(
num_diffusion_timesteps
):
t1
=
i
/
num_diffusion_timesteps
t2
=
(
i
+
1
)
/
num_diffusion_timesteps
betas
.
append
(
min
(
1
-
alpha_bar
(
t2
)
/
alpha_bar
(
t1
),
max_beta
))
return
torch
.
tensor
(
betas
,
dtype
=
torch
.
float64
)
class
GaussianDDPMScheduler
(
nn
.
Module
,
ConfigMixin
):
class
GaussianDDPMScheduler
(
nn
.
Module
,
ConfigMixin
):
config_name
=
SAMPLING_CONFIG_NAME
config_name
=
SAMPLING_CONFIG_NAME
...
@@ -48,6 +69,12 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
...
@@ -48,6 +69,12 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
if
beta_schedule
==
"linear"
:
if
beta_schedule
==
"linear"
:
betas
=
linear_beta_schedule
(
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
)
betas
=
linear_beta_schedule
(
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
)
elif
beta_schedule
==
"squaredcos_cap_v2"
:
# GLIDE cosine schedule
betas
=
betas_for_alpha_bar
(
timesteps
,
lambda
t
:
math
.
cos
((
t
+
0.008
)
/
1.008
*
math
.
pi
/
2
)
**
2
,
)
else
:
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment