OpenDAS / diffusers

Commit 7f6a36c3 (unverified)
Authored Jun 07, 2022 by Anton Lozhkov; committed by GitHub on Jun 07, 2022

Merge pull request #2 from huggingface/add-glide

+ cosine schedule and unet config

Parents: 2db090de, 747f42d0
Showing 3 changed files with 78 additions and 13 deletions (+78, -13):

    models/vision/glide/run_glide.py           +16  -0
    src/diffusers/models/unet_glide.py         +35  -13
    src/diffusers/schedulers/gaussian_ddpm.py  +27  -0
models/vision/glide/run_glide.py  (new file, view file @ 7f6a36c3)

+import torch
+
+from .modeling_glide import GLIDE
+from diffusers import UNetGLIDEModel, GaussianDDPMScheduler
+
+generator = torch.Generator()
+generator = generator.manual_seed(0)
+
+# 1. Load models
+scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base")
+model = UNetGLIDEModel.from_pretrained("fusing/glide-base")
+
+pipeline = GLIDE(model, scheduler)
+img = pipeline(generator)
+
+print(img)
src/diffusers/models/unet_glide.py  (view file @ 7f6a36c3)

 import math
 from abc import abstractmethod

-import torch as th
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
+from ..configuration_utils import Config
+from ..modeling_utils import PreTrainedModel


 def convert_module_to_f16(l):
     """
...
@@ -94,13 +97,13 @@ def timestep_embedding(timesteps, dim, max_period=10000):
     :return: an [N x dim] Tensor of positional embeddings.
     """
     half = dim // 2
-    freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
         device=timesteps.device
     )
     args = timesteps[:, None].float() * freqs[None]
-    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
     if dim % 2:
-        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
     return embedding
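As a quick illustration of the docstring above (not part of the commit, and assuming `timestep_embedding` from this file is in scope): the returned tensor stacks `dim // 2` cosine channels followed by `dim // 2` sine channels per timestep, with frequencies spaced geometrically from 1 down to roughly 1/max_period.

import torch

timesteps = torch.tensor([0, 10, 999])
emb = timestep_embedding(timesteps, dim=128)         # same signature as in the hunk header above
print(emb.shape)                                     # torch.Size([3, 128])
print(torch.allclose(emb[0, :64], torch.ones(64)))   # True: at t=0, cos(0) = 1 for every frequency
print(torch.allclose(emb[0, 64:], torch.zeros(64)))  # True: and sin(0) = 0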
...
@@ -298,7 +301,7 @@ class ResBlock(TimestepBlock):
         emb_out = emb_out[..., None]
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = th.chunk(emb_out, 2, dim=1)
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
             h = out_norm(h) * (1 + scale) + shift
             h = out_rest(h)
         else:
...
@@ -376,16 +379,16 @@ class QKVAttention(nn.Module):
         if encoder_kv is not None:
             assert encoder_kv.shape[1] == self.n_heads * ch * 2
             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = th.cat([ek, k], dim=-1)
+            k = torch.cat([ek, k], dim=-1)
-            v = th.cat([ev, v], dim=-1)
+            v = torch.cat([ev, v], dim=-1)
         scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
+        a = torch.einsum("bts,bcs->bct", weight, v)
         return a.reshape(bs, -1, length)


-class UNetGLIDEModel(nn.Module):
+class UNetGLIDEModel(PreTrainedModel, Config):
     """
     The full UNet model with attention and timestep embedding.
...
@@ -435,6 +438,25 @@ class UNetGLIDEModel(nn.Module):
         encoder_channels=None,
     ):
         super().__init__()
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            encoder_channels=encoder_channels,
+        )
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
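A plausible reading of the `self.register(...)` call added above (the diff does not show its implementation, which lives in `..configuration_utils` / `..modeling_utils`): the constructor arguments are recorded on the instance so they can be written out as a config and used to re-create the model, which is what `UNetGLIDEModel.from_pretrained("fusing/glide-base")` relies on in `run_glide.py` above. A toy sketch of that pattern, with an entirely hypothetical `register`:

class ConfigRegisterMixin:
    # hypothetical stand-in for the real Config / PreTrainedModel machinery
    def register(self, **kwargs):
        # keep the init arguments around so they can later be serialized to a config file
        self._config = dict(kwargs)

class TinyUNet(ConfigRegisterMixin):
    def __init__(self, in_channels=3, model_channels=64):
        self.register(in_channels=in_channels, model_channels=model_channels)

print(TinyUNet()._config)  # {'in_channels': 3, 'model_channels': 64}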
...
@@ -448,7 +470,7 @@ class UNetGLIDEModel(nn.Module):
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
+        self.dtype = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
...
@@ -637,7 +659,7 @@ class UNetGLIDEModel(nn.Module):
             hs.append(h)
         h = self.middle_block(h, emb)
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
+            h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
         h = h.type(x.dtype)
         return self.out(h)
src/diffusers/schedulers/gaussian_ddpm.py  (view file @ 7f6a36c3)

...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import math

 from torch import nn

 from ..configuration_utils import ConfigMixin
...
@@ -24,6 +25,26 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
     return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return torch.tensor(betas, dtype=torch.float64)
+
+
 class GaussianDDPMScheduler(nn.Module, ConfigMixin):
     config_name = SAMPLING_CONFIG_NAME
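To illustrate the docstring above (illustrative only, not part of the commit, assuming `betas_for_alpha_bar` from this file is in scope): the cumulative product of `1 - beta` recovered from the returned betas tracks the `alpha_bar` function that was discretized, up to the `max_beta` clipping near t = 1.

import math
import torch

alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2  # the GLIDE cosine schedule used below
betas = betas_for_alpha_bar(1000, alpha_bar)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
# before any clipping kicks in, alphas_cumprod[i] == alpha_bar((i + 1) / 1000) / alpha_bar(0)
print(alphas_cumprod[499].item())       # ~0.49
print(alpha_bar(0.5) / alpha_bar(0.0))  # same value, ~0.49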
...
@@ -48,6 +69,12 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         if beta_schedule == "linear":
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+        elif beta_schedule == "squaredcos_cap_v2":
+            # GLIDE cosine schedule
+            betas = betas_for_alpha_bar(
+                timesteps,
+                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+            )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...
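For context, a sketch only: the full `GaussianDDPMScheduler` constructor signature is not shown in this diff, so the keyword names below are assumed from the branch above. The new schedule would presumably be selected by name like this, whereas `run_glide.py` gets the same configuration indirectly via `GaussianDDPMScheduler.from_config("fusing/glide-base")`.

# Assumed keyword names, inferred from the `if beta_schedule == "linear"` branch above.
scheduler = GaussianDDPMScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")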