renzhc / diffusers_dcu · Commits

Commit 7f6a36c3 (unverified)
Authored Jun 07, 2022 by Anton Lozhkov, committed by GitHub on Jun 07, 2022

Merge pull request #2 from huggingface/add-glide

+ cosine schedule and unet config

Parents: 2db090de, 747f42d0
Showing 3 changed files with 78 additions and 13 deletions (+78 -13):

models/vision/glide/run_glide.py            +16  -0
src/diffusers/models/unet_glide.py          +35  -13
src/diffusers/schedulers/gaussian_ddpm.py   +27  -0
models/vision/glide/run_glide.py (+16 lines, all added)

import torch

from .modeling_glide import GLIDE
from diffusers import UNetGLIDEModel, GaussianDDPMScheduler

generator = torch.Generator()
generator = generator.manual_seed(0)

# 1. Load models
scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base")
model = UNetGLIDEModel.from_pretrained("fusing/glide-base")

pipeline = GLIDE(model, scheduler)

img = pipeline(generator)

print(img)
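The script above is the end-to-end example this commit adds. The diff does not show what `pipeline(generator)` returns; a minimal post-processing sketch, assuming `img` is a `[1, 3, H, W]` float tensor in `[-1, 1]` (an assumption, not confirmed by this commit):

# Hypothetical post-processing; assumes `img` is a [1, 3, H, W] float
# tensor in [-1, 1], which this diff does not confirm.
import torch
from PIL import Image

def to_pil(img_tensor):
    # Map [-1, 1] -> [0, 255], drop the batch dim, move channels last.
    arr = ((img_tensor.clamp(-1, 1) + 1) * 127.5).to(torch.uint8)
    arr = arr[0].permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(arr)

# to_pil(img).save("glide_sample.png")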
src/diffusers/models/unet_glide.py (+35 -13)

 import math
 from abc import abstractmethod

-import torch as th
+import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from ..configuration_utils import Config
+from ..modeling_utils import PreTrainedModel


 def convert_module_to_f16(l):
     """
...
@@ -94,13 +97,13 @@ def timestep_embedding(timesteps, dim, max_period=10000):
     :return: an [N x dim] Tensor of positional embeddings.
     """
     half = dim // 2
-    freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
         device=timesteps.device
     )
     args = timesteps[:, None].float() * freqs[None]
-    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
     if dim % 2:
-        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
     return embedding
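The change above only swaps `th` for `torch`; the function itself builds standard sinusoidal embeddings. A quick shape check, not part of the commit, assuming the function above is in scope:

# Quick shape check for timestep_embedding (illustrative, not in the commit).
import torch

timesteps = torch.tensor([0, 10, 999])        # [N]
emb = timestep_embedding(timesteps, dim=128)
print(emb.shape)                              # torch.Size([3, 128])
# The first half of each row holds cosines, the second half sines, of the
# timestep scaled by geometrically spaced frequencies.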
...
@@ -298,7 +301,7 @@ class ResBlock(TimestepBlock):
             emb_out = emb_out[..., None]
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = th.chunk(emb_out, 2, dim=1)
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
             h = out_norm(h) * (1 + scale) + shift
             h = out_rest(h)
         else:
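The `use_scale_shift_norm` branch is FiLM-style conditioning: the timestep embedding is projected to `2 * C` channels, split into a scale and a shift, and applied after the normalization layer. A standalone sketch of the same pattern, with hypothetical layer sizes that are not taken from this file:

# Standalone sketch of scale-shift conditioning; sizes and layer names
# are illustrative, not the actual ResBlock internals.
import torch
import torch.nn as nn

channels = 64
norm = nn.GroupNorm(32, channels)
emb_proj = nn.Linear(512, 2 * channels)        # timestep embedding -> scale/shift

h = torch.randn(2, channels, 16, 16)
emb = torch.randn(2, 512)

emb_out = emb_proj(emb)[..., None, None]       # [B, 2C, 1, 1]
scale, shift = torch.chunk(emb_out, 2, dim=1)  # each [B, C, 1, 1]
h = norm(h) * (1 + scale) + shift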
...
@@ -376,16 +379,16 @@ class QKVAttention(nn.Module):
         if encoder_kv is not None:
             assert encoder_kv.shape[1] == self.n_heads * ch * 2
             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = th.cat([ek, k], dim=-1)
+            k = torch.cat([ek, k], dim=-1)
-            v = th.cat([ev, v], dim=-1)
+            v = torch.cat([ev, v], dim=-1)
         scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
+        a = torch.einsum("bts,bcs->bct", weight, v)
         return a.reshape(bs, -1, length)
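The `1 / math.sqrt(math.sqrt(ch))` factor applied to both `q` and `k` is equivalent to the usual `1 / sqrt(ch)` attention scaling, but splitting it between the two operands keeps the intermediate products smaller, which is what the "more stable with f16" comment refers to. A quick numerical check, not part of the diff:

# Scaling q and k by ch**-0.25 each matches scaling the product by ch**-0.5.
import math
import torch

q = torch.randn(4, 64, 10)   # [bs * n_heads, ch, length]
k = torch.randn(4, 64, 10)
ch = 64

scale = 1 / math.sqrt(math.sqrt(ch))
w1 = torch.einsum("bct,bcs->bts", q * scale, k * scale)
w2 = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
print(torch.allclose(w1, w2, atol=1e-5))       # True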

-class UNetGLIDEModel(nn.Module):
+class UNetGLIDEModel(PreTrainedModel, Config):
     """
     The full UNet model with attention and timestep embedding.
...
@@ -435,6 +438,25 @@ class UNetGLIDEModel(nn.Module):
         encoder_channels=None,
     ):
         super().__init__()
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            encoder_channels=encoder_channels,
+        )
+
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
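The new `self.register(...)` call is the "unet config" part of the commit message: every constructor argument is recorded so the architecture can be rebuilt from a saved config, which is what lets `UNetGLIDEModel.from_pretrained("fusing/glide-base")` in `run_glide.py` work. A simplified sketch of the idea, not the actual `Config`/`PreTrainedModel` implementation, whose save/load details are outside this diff:

# Simplified stand-in for what register() enables; the real Config class
# in this repo may differ.
class ConfigRecorder:
    def register(self, **kwargs):
        # Remember the constructor arguments so they can be written to a
        # config file next to the weights and read back later.
        self._config = dict(kwargs)

    def config_dict(self):
        return dict(self._config)

rec = ConfigRecorder()
rec.register(in_channels=3, model_channels=192, num_res_blocks=3)
print(rec.config_dict())
# {'in_channels': 3, 'model_channels': 192, 'num_res_blocks': 3}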
...
@@ -448,7 +470,7 @@ class UNetGLIDEModel(nn.Module):
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
+        self.dtype = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
...
@@ -637,7 +659,7 @@ class UNetGLIDEModel(nn.Module):
             hs.append(h)
         h = self.middle_block(h, emb)
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
+            h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
         h = h.type(x.dtype)
         return self.out(h)
src/diffusers/schedulers/gaussian_ddpm.py (+27 -0)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import math
 from torch import nn

 from ..configuration_utils import ConfigMixin
...
@@ -24,6 +25,26 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
     return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return torch.tensor(betas, dtype=torch.float64)
+
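`betas_for_alpha_bar` is the new helper behind the cosine schedule. A short worked example, not part of the commit, that evaluates it with the same `alpha_bar` used below and checks that the cumulative product of `(1 - beta)` telescopes back to `alpha_bar`, up to the cap and the normalization by `alpha_bar(0)`:

# Worked example for the GLIDE cosine schedule (illustrative only).
import math
import torch

alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

betas = betas_for_alpha_bar(1000, alpha_bar)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

print(betas[0].item(), betas[-1].item())   # tiny at t=0; final beta capped at max_beta=0.999

# Away from the cap the product telescopes:
# prod_{j<=i} alpha_bar(t_{j+1}) / alpha_bar(t_j) = alpha_bar(t_{i+1}) / alpha_bar(0)
i = 500
expected = alpha_bar((i + 1) / 1000) / alpha_bar(0)
print(torch.allclose(alphas_cumprod[i], torch.tensor(expected, dtype=torch.float64)))   # True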

 class GaussianDDPMScheduler(nn.Module, ConfigMixin):

     config_name = SAMPLING_CONFIG_NAME
...
@@ -48,6 +69,12 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         if beta_schedule == "linear":
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+        elif beta_schedule == "squaredcos_cap_v2":
+            # GLIDE cosine schedule
+            betas = betas_for_alpha_bar(
+                timesteps,
+                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+            )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...