renzhc / diffusers_dcu · Commits

Commit 528b1293
authored Jun 09, 2022 by anton-l

make style

Parents: f23bb3e8, cbb19ee8
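The commit message "make style" points to an automated formatting pass. As a hedged aside (the Makefile itself is not part of this diff), Hugging Face-style repositories usually wire such a target to the black and isort versions pinned in setup.py and the dependency table below, which would account for the quote normalization, import sorting, and line re-wrapping seen throughout this commit. A minimal sketch of what that target might run, with hypothetical directory names:

import subprocess

# Hypothetical sketch of a `make style` target; the real Makefile is not shown in this diff.
check_dirs = ["src", "models", "tests"]  # assumed directories, adjust to the actual repo layout
subprocess.run(["black", *check_dirs], check=True)  # reformat code (quotes, wrapping, trailing commas)
subprocess.run(["isort", *check_dirs], check=True)  # sort and group imports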
Changes: 23
Showing 20 changed files with 321 additions and 303 deletions (+321, −303)
models/vision/ddim/example.py (+5, −2)
models/vision/ddim/modeling_ddim.py (+17, −7)
models/vision/ddim/run_inference.py (+4, −2)
models/vision/ddpm/example.py (+17, −3)
models/vision/ddpm/modeling_ddpm.py (+21, −7)
models/vision/glide/convert_weights.py (+17, −7)
models/vision/glide/modeling_glide.py (+41, −25)
models/vision/glide/run_glide.py (+5, −3)
setup.py (+2, −0)
src/diffusers/__init__.py (+1, −1)
src/diffusers/configuration_utils.py (+19, −24)
src/diffusers/dependency_versions_table.py (+5, −19)
src/diffusers/dynamic_modules_utils.py (+2, −1)
src/diffusers/modeling_utils.py (+6, −3)
src/diffusers/models/__init__.py (+2, −2)
src/diffusers/models/unet.py (+13, −10)
src/diffusers/models/unet_glide.py (+38, −38)
src/diffusers/models/unet_ldm.py (+101, −144)
src/diffusers/pipeline_utils.py (+2, −3)
src/diffusers/schedulers/gaussian_ddpm.py (+3, −2)
models/vision/ddim/example.py

 #!/usr/bin/env python3
 import os
 import pathlib

-from modeling_ddim import DDIM
-import PIL.Image
-import numpy as np
+import numpy as np
+import PIL.Image
+
+from modeling_ddim import DDIM

 model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]

 for model_id in model_ids:
 ...
models/vision/ddim/modeling_ddim.py

 ...
@@ -14,13 +14,13 @@
 # limitations under the License.

-from diffusers import DiffusionPipeline
-import tqdm
-import torch
+import torch
+import tqdm
+from diffusers import DiffusionPipeline


 class DDIM(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
 ...
@@ -34,12 +34,16 @@ class DDIM(DiffusionPipeline):
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

         self.unet.to(torch_device)

-        image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
+        image = self.noise_scheduler.sample_noise(
+            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            device=torch_device,
+            generator=generator,
+        )

         for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
             # get actual t and t-1
             train_step = inference_step_times[t]
             prev_train_step = inference_step_times[t - 1] if t > 0 else -1

             # compute alphas
             alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
 ...
@@ -50,8 +54,14 @@ class DDIM(DiffusionPipeline):
             beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()

             # compute relevant coefficients
-            coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta
+            coeff_1 = (
+                (alpha_prod_t_prev - alpha_prod_t).sqrt()
+                * alpha_prod_t_prev_rsqrt
+                * beta_prod_t_prev_sqrt
+                / beta_prod_t_sqrt
+                * eta
+            )
             coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()

             # model forward
             with torch.no_grad():
 ...
models/vision/ddim/run_inference.py

 #!/usr/bin/env python3
 # !pip install diffusers
-from modeling_ddim import DDIM
-import PIL.Image
-import numpy as np
+import numpy as np
+import PIL.Image
+
+from modeling_ddim import DDIM

 model_id = "fusing/ddpm-cifar10"
 model_id = "fusing/ddpm-lsun-bedroom"
 ...
models/vision/ddpm/example.py

 #!/usr/bin/env python3
 import os
 import pathlib

-from modeling_ddpm import DDPM
-import PIL.Image
-import numpy as np
+import numpy as np
+import PIL.Image
+
+from modeling_ddpm import DDPM

-model_ids = ["ddpm-lsun-cat", "ddpm-lsun-cat-ema", "ddpm-lsun-church-ema", "ddpm-lsun-church", "ddpm-lsun-bedroom", "ddpm-lsun-bedroom-ema", "ddpm-cifar10-ema", "ddpm-cifar10", "ddpm-celeba-hq", "ddpm-celeba-hq-ema"]
+model_ids = [
+    "ddpm-lsun-cat",
+    "ddpm-lsun-cat-ema",
+    "ddpm-lsun-church-ema",
+    "ddpm-lsun-church",
+    "ddpm-lsun-bedroom",
+    "ddpm-lsun-bedroom-ema",
+    "ddpm-cifar10-ema",
+    "ddpm-cifar10",
+    "ddpm-celeba-hq",
+    "ddpm-celeba-hq-ema",
+]

 for model_id in model_ids:
     path = os.path.join("/home/patrick/images/hf", model_id)
 ...
models/vision/ddpm/modeling_ddpm.py

 ...
@@ -14,13 +14,13 @@
 # limitations under the License.

-from diffusers import DiffusionPipeline
-import tqdm
-import torch
+import torch
+import tqdm
+from diffusers import DiffusionPipeline


 class DDPM(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
 ...
@@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline):
         self.unet.to(torch_device)

         # 1. Sample gaussian noise
-        image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
+        image = self.noise_scheduler.sample_noise(
+            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            device=torch_device,
+            generator=generator,
+        )
         for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
             # i) define coefficients for time step t
             clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
             clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
-            image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t))
-            clipped_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t))
+            image_coeff = (
+                (1 - self.noise_scheduler.get_alpha_prod(t - 1))
+                * torch.sqrt(self.noise_scheduler.get_alpha(t))
+                / (1 - self.noise_scheduler.get_alpha_prod(t))
+            )
+            clipped_coeff = (
+                torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1))
+                * self.noise_scheduler.get_beta(t)
+                / (1 - self.noise_scheduler.get_alpha_prod(t))
+            )

             # ii) predict noise residual
             with torch.no_grad():
 ...
@@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline):
             prev_image = clipped_coeff * pred_mean + image_coeff * image

             # iv) sample variance
-            prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
+            prev_variance = self.noise_scheduler.sample_variance(
+                t, prev_image.shape, device=torch_device, generator=generator
+            )

             # v) sample x_{t-1} ~ N(prev_image, prev_variance)
             sampled_prev_image = prev_image + prev_variance
 ...
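Aside on the coefficients reformatted above (a sketch for orientation, not part of the commit): clipped_image_coeff and clipped_noise_coeff are the standard DDPM terms that recover x_0 from a noised sample x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, while clipped_coeff and image_coeff weight the predicted x_0 and the current x_t in the posterior mean. A self-contained numerical check with a made-up beta schedule (plain tensors, not the scheduler API used in the file):

import torch

# toy schedule, purely illustrative
betas = torch.linspace(1e-4, 2e-2, 1000)
alphas = 1.0 - betas
alphas_prod = torch.cumprod(alphas, dim=0)

t = 500
x0 = torch.randn(4)
eps = torch.randn(4)
xt = torch.sqrt(alphas_prod[t]) * x0 + torch.sqrt(1 - alphas_prod[t]) * eps

# the same coefficients as in modeling_ddpm.py, written with plain tensors
clipped_image_coeff = 1 / torch.sqrt(alphas_prod[t])
clipped_noise_coeff = torch.sqrt(1 / alphas_prod[t] - 1)

# they invert the forward noising, so x_0 is recovered exactly
x0_recovered = clipped_image_coeff * xt - clipped_noise_coeff * eps
print(torch.allclose(x0_recovered, x0, atol=1e-5))  # True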
models/vision/glide/convert_weights.py

 import torch
 from torch import nn

-from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
+from diffusers import (
+    ClassifierFreeGuidanceScheduler,
+    GlideDDIMScheduler,
+    GLIDESuperResUNetModel,
+    GLIDETextToImageUNetModel,
+)
 from modeling_glide import GLIDE, CLIPTextModel
 from transformers import CLIPTextConfig, GPT2Tokenizer
 ...
@@ -22,7 +27,9 @@ config = CLIPTextConfig(
     use_padding_embeddings=True,
 )
 model = CLIPTextModel(config).eval()
-tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>")
+tokenizer = GPT2Tokenizer(
+    "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
+)

 hf_encoder = model.text_model
 ...
@@ -97,10 +104,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False)
 upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")

-glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer,
-              upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler)
+glide = GLIDE(
+    text_unet=text2im_model,
+    text_noise_scheduler=text_scheduler,
+    text_encoder=model,
+    tokenizer=tokenizer,
+    upscale_unet=superres_model,
+    upscale_noise_scheduler=upscale_scheduler,
+)

 glide.save_pretrained("./glide-base")
models/vision/glide/modeling_glide.py

 ...
@@ -18,10 +18,20 @@ import math
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union

+import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
+
+import tqdm
+from diffusers import (
+    ClassifierFreeGuidanceScheduler,
+    DiffusionPipeline,
+    GlideDDIMScheduler,
+    GLIDESuperResUNetModel,
+    GLIDETextToImageUNetModel,
+)
+from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
 ...
@@ -34,14 +44,6 @@ from transformers.utils import (
 )

-import numpy as np
-import torch
-import tqdm
-from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
-from transformers import GPT2Tokenizer
-

 #####################
 # START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
 #####################
 ...
@@ -725,12 +727,16 @@ class GLIDE(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
-        upscale_noise_scheduler: GlideDDIMScheduler
+        upscale_noise_scheduler: GlideDDIMScheduler,
     ):
         super().__init__()
         self.register_modules(
-            text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
-            upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
+            text_unet=text_unet,
+            text_noise_scheduler=text_noise_scheduler,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            upscale_unet=upscale_unet,
+            upscale_noise_scheduler=upscale_noise_scheduler,
         )

     def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
 ...
@@ -746,9 +752,7 @@ class GLIDE(DiffusionPipeline):
             + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
         )
         posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(
-            scheduler.posterior_log_variance_clipped, t, x_t.shape
-        )
+        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
         assert (
             posterior_mean.shape[0]
             == posterior_variance.shape[0]
 ...
@@ -869,19 +873,30 @@ class GLIDE(DiffusionPipeline):
         # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
         upsample_temp = 0.997

-        image = self.upscale_noise_scheduler.sample_noise(
-            (batch_size, 3, 256, 256), device=torch_device, generator=generator
-        ) * upsample_temp
+        image = (
+            self.upscale_noise_scheduler.sample_noise(
+                (batch_size, 3, 256, 256), device=torch_device, generator=generator
+            )
+            * upsample_temp
+        )

         num_timesteps = len(self.upscale_noise_scheduler)
-        for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)):
+        for t in tqdm.tqdm(
+            reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
+        ):
             # i) define coefficients for time step t
             clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
             clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
-            image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(
-                self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
-            clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta(
-                t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            image_coeff = (
+                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
+            clipped_coeff = (
+                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * self.upscale_noise_scheduler.get_beta(t)
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )

             # ii) predict noise residual
             time_input = torch.tensor([t] * image.shape[0], device=torch_device)
 ...
@@ -895,8 +910,9 @@ class GLIDE(DiffusionPipeline):
             prev_image = clipped_coeff * pred_mean + image_coeff * image

             # iv) sample variance
-            prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device,
-                                                                         generator=generator)
+            prev_variance = self.upscale_noise_scheduler.sample_variance(
+                t, prev_image.shape, device=torch_device, generator=generator
+            )

             # v) sample x_{t-1} ~ N(prev_image, prev_variance)
             sampled_prev_image = prev_image + prev_variance
 ...
models/vision/glide/run_glide.py

 import torch
-from diffusers import DiffusionPipeline
+
 import PIL.Image
+from diffusers import DiffusionPipeline

 generator = torch.Generator()
 generator = generator.manual_seed(0)
 ...
@@ -15,7 +17,7 @@ img = pipeline("a crayon drawing of a corgi", generator)
 # process image to PIL
 img = img.squeeze(0)
 img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
 image_pil = PIL.Image.fromarray(img)

 # save image
 ...
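The post-processing lines above map the model output from the [-1, 1] range to 8-bit pixel values: (img + 1) * 127.5 sends -1 to 0 and +1 to 255, and the round/clamp/uint8 chain makes the result safe to hand to PIL. A small stand-alone illustration (the tensor here is a stand-in, not the pipeline output):

import torch

img = torch.rand(3, 64, 64) * 2 - 1  # pretend model output in [-1, 1]
pixels = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8)
print(pixels.min().item(), pixels.max().item())  # values now lie within [0, 255]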
setup.py

 ...
@@ -84,6 +84,7 @@ _deps = [
     "isort>=5.5.4",
     "numpy",
     "pytest",
+    "regex!=2019.12.17",
     "requests",
     "torch>=1.4",
     "torchvision",
 ...
@@ -168,6 +169,7 @@ install_requires = [
     deps["filelock"],
     deps["huggingface-hub"],
     deps["numpy"],
+    deps["regex"],
     deps["requests"],
     deps["torch"],
     deps["torchvision"],
 ...
src/diffusers/__init__.py

 ...
@@ -6,7 +6,7 @@ __version__ = "0.0.1"
 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
-from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
+from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .models.vqvae import VQModel
 from .pipeline_utils import DiffusionPipeline
 ...
src/diffusers/configuration_utils.py

 ...
@@ -23,13 +23,13 @@ import os
 import re
 from typing import Any, Dict, Tuple, Union

-from requests import HTTPError
 from huggingface_hub import hf_hub_download
+from requests import HTTPError

+from . import __version__
 from .utils import (
-    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     DIFFUSERS_CACHE,
+    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     EntryNotFoundError,
     RepositoryNotFoundError,
     RevisionNotFoundError,
 ...
@@ -37,9 +37,6 @@ from .utils import (
 )

-from . import __version__
-

 logger = logging.get_logger(__name__)

 _re_configuration_file = re.compile(r"config\.(.*)\.json")
 ...
@@ -95,9 +92,7 @@ class ConfigMixin:
     @classmethod
     def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
-        config_dict = cls.get_config_dict(
-            pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+        config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
         init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
 ...
@@ -157,16 +152,16 @@ class ConfigMixin:
         except RepositoryNotFoundError:
             raise EnvironmentError(
-                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
-                "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
-                "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
-                "`use_auth_token=True`."
+                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
+                " on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
+                " having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
+                " pass `use_auth_token=True`."
             )
         except RevisionNotFoundError:
             raise EnvironmentError(
-                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
-                f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
-                "available revisions."
+                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+                " this model name. Check the model page at"
+                f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
             )
         except EntryNotFoundError:
             raise EnvironmentError(
 ...
@@ -174,14 +169,16 @@ class ConfigMixin:
             )
         except HTTPError as err:
             raise EnvironmentError(
-                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+                "There was a specific connection error when trying to load"
+                f" {pretrained_model_name_or_path}:\n{err}"
             )
         except ValueError:
             raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
-                f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
-                f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the"
-                " library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+                " run the library in offline mode at"
+                " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
             )
         except EnvironmentError:
             raise EnvironmentError(
 ...
@@ -195,9 +192,7 @@ class ConfigMixin:
             # Load config dict
             config_dict = cls._dict_from_json_file(config_file)
         except (json.JSONDecodeError, UnicodeDecodeError):
-            raise EnvironmentError(
-                f"It looks like the config file at '{config_file}' is not a valid JSON file."
-            )
+            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

         return config_dict
 ...
src/diffusers/dependency_versions_table.py

 ...
@@ -3,29 +3,15 @@
 # 2. run `make deps_table_update``
 deps = {
     "Pillow": "Pillow",
-    "accelerate": "accelerate>=0.9.0",
     "black": "black~=22.0,>=22.3",
-    "codecarbon": "codecarbon==1.2.0",
-    "dataclasses": "dataclasses",
-    "datasets": "datasets",
-    "GitPython": "GitPython<3.1.19",
-    "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
-    "importlib_metadata": "importlib_metadata",
+    "filelock": "filelock",
+    "flake8": "flake8>=3.8.3",
+    "huggingface-hub": "huggingface-hub",
     "isort": "isort>=5.5.4",
-    "numpy": "numpy>=1.17",
+    "numpy": "numpy",
     "pytest": "pytest",
-    "pytest-timeout": "pytest-timeout",
-    "pytest-xdist": "pytest-xdist",
-    "python": "python>=3.7.0",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
-    "sagemaker": "sagemaker>=2.31.0",
-    "tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
     "torch": "torch>=1.4",
-    "torchaudio": "torchaudio",
     "torchvision": "torchvision",
-    "tqdm": "tqdm>=4.27",
-    "unidic": "unidic>=1.0.2",
-    "unidic_lite": "unidic_lite>=1.0.7",
-    "uvicorn": "uvicorn",
 }
View file @
528b1293
...
@@ -23,7 +23,8 @@ from pathlib import Path
...
@@ -23,7 +23,8 @@ from pathlib import Path
from
typing
import
Dict
,
Optional
,
Union
from
typing
import
Dict
,
Optional
,
Union
from
huggingface_hub
import
cached_download
from
huggingface_hub
import
cached_download
from
.utils
import
HF_MODULES_CACHE
,
DIFFUSERS_DYNAMIC_MODULE_NAME
,
logging
from
.utils
import
DIFFUSERS_DYNAMIC_MODULE_NAME
,
HF_MODULES_CACHE
,
logging
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
src/diffusers/modeling_utils.py

 ...
@@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 from torch import Tensor, device

-from requests import HTTPError
 from huggingface_hub import hf_hub_download
+from requests import HTTPError

 from .utils import (
     CONFIG_NAME,
 ...
@@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module):
                 f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
             )
         except EntryNotFoundError:
-            raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.")
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
+            )
         except HTTPError as err:
             raise EnvironmentError(
-                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+                "There was a specific connection error when trying to load"
+                f" {pretrained_model_name_or_path}:\n{err}"
             )
         except ValueError:
             raise EnvironmentError(
 ...
src/diffusers/models/__init__.py

 ...
@@ -17,6 +17,6 @@
 # limitations under the License.

 from .unet import UNetModel
-from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
+from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .unet_ldm import UNetLDMModel
 from .vqvae import VQModel
src/diffusers/models/unet.py

 ...
@@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast
 from torch.optim import Adam
 from torch.utils import data
-from torchvision import transforms, utils
 from PIL import Image
+from torchvision import transforms, utils
 from tqdm import tqdm

 from ..configuration_utils import ConfigMixin
 ...
@@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin):
 # dataset classes


 class Dataset(data.Dataset):
-    def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']):
+    def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
         super().__init__()
         self.folder = folder
         self.image_size = image_size
-        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
+        self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]

-        self.transform = transforms.Compose([
-            transforms.Resize(image_size),
-            transforms.RandomHorizontalFlip(),
-            transforms.CenterCrop(image_size),
-            transforms.ToTensor()
-        ])
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(image_size),
+                transforms.RandomHorizontalFlip(),
+                transforms.CenterCrop(image_size),
+                transforms.ToTensor(),
+            ]
+        )

     def __len__(self):
         return len(self.paths)
 ...
@@ -359,7 +362,7 @@ class Dataset(data.Dataset):
 # trainer class


-class EMA():
+class EMA:
     def __init__(self, beta):
         super().__init__()
         self.beta = beta
 ...
src/diffusers/models/unet_glide.py

 ...
@@ -664,7 +664,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
         num_heads_upsample=-1,
         use_scale_shift_norm=False,
         resblock_updown=False,
-        transformer_dim=512
+        transformer_dim=512,
     ):
         super().__init__(
             in_channels=in_channels,
 ...
@@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
             num_heads_upsample=num_heads_upsample,
             use_scale_shift_norm=use_scale_shift_norm,
             resblock_updown=resblock_updown,
-            transformer_dim=transformer_dim
+            transformer_dim=transformer_dim,
         )
         self.register(
             in_channels=in_channels,
 ...
@@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
             num_heads_upsample=num_heads_upsample,
             use_scale_shift_norm=use_scale_shift_norm,
             resblock_updown=resblock_updown,
-            transformer_dim=transformer_dim
+            transformer_dim=transformer_dim,
         )

         self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
 ...
src/diffusers/models/unet_ldm.py
View file @
528b1293
from
inspect
import
isfunction
from
abc
import
abstractmethod
import
math
import
math
from
abc
import
abstractmethod
from
inspect
import
isfunction
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
try
:
try
:
from
einops
import
re
peat
,
rearrange
from
einops
import
re
arrange
,
repeat
except
:
except
:
print
(
"Einops is not installed"
)
print
(
"Einops is not installed"
)
pass
pass
...
@@ -16,12 +17,13 @@ except:
...
@@ -16,12 +17,13 @@ except:
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
def
exists
(
val
):
def
exists
(
val
):
return
val
is
not
None
return
val
is
not
None
def
uniq
(
arr
):
def
uniq
(
arr
):
return
{
el
:
True
for
el
in
arr
}.
keys
()
return
{
el
:
True
for
el
in
arr
}.
keys
()
def
default
(
val
,
d
):
def
default
(
val
,
d
):
...
@@ -53,20 +55,13 @@ class GEGLU(nn.Module):
...
@@ -53,20 +55,13 @@ class GEGLU(nn.Module):
class
FeedForward
(
nn
.
Module
):
class
FeedForward
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
dim_out
=
None
,
mult
=
4
,
glu
=
False
,
dropout
=
0.
):
def
__init__
(
self
,
dim
,
dim_out
=
None
,
mult
=
4
,
glu
=
False
,
dropout
=
0.
0
):
super
().
__init__
()
super
().
__init__
()
inner_dim
=
int
(
dim
*
mult
)
inner_dim
=
int
(
dim
*
mult
)
dim_out
=
default
(
dim_out
,
dim
)
dim_out
=
default
(
dim_out
,
dim
)
project_in
=
nn
.
Sequential
(
project_in
=
nn
.
Sequential
(
nn
.
Linear
(
dim
,
inner_dim
),
nn
.
GELU
())
if
not
glu
else
GEGLU
(
dim
,
inner_dim
)
nn
.
Linear
(
dim
,
inner_dim
),
nn
.
GELU
()
self
.
net
=
nn
.
Sequential
(
project_in
,
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
inner_dim
,
dim_out
))
)
if
not
glu
else
GEGLU
(
dim
,
inner_dim
)
self
.
net
=
nn
.
Sequential
(
project_in
,
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
inner_dim
,
dim_out
)
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
return
self
.
net
(
x
)
...
@@ -90,17 +85,17 @@ class LinearAttention(nn.Module):
...
@@ -90,17 +85,17 @@ class LinearAttention(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
heads
=
heads
self
.
heads
=
heads
hidden_dim
=
dim_head
*
heads
hidden_dim
=
dim_head
*
heads
self
.
to_qkv
=
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_qkv
=
nn
.
Conv2d
(
dim
,
hidden_dim
*
3
,
1
,
bias
=
False
)
self
.
to_out
=
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
self
.
to_out
=
nn
.
Conv2d
(
hidden_dim
,
dim
,
1
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
b
,
c
,
h
,
w
=
x
.
shape
b
,
c
,
h
,
w
=
x
.
shape
qkv
=
self
.
to_qkv
(
x
)
qkv
=
self
.
to_qkv
(
x
)
q
,
k
,
v
=
rearrange
(
qkv
,
'
b (qkv heads c) h w -> qkv b heads c (h w)
'
,
heads
=
self
.
heads
,
qkv
=
3
)
q
,
k
,
v
=
rearrange
(
qkv
,
"
b (qkv heads c) h w -> qkv b heads c (h w)
"
,
heads
=
self
.
heads
,
qkv
=
3
)
k
=
k
.
softmax
(
dim
=-
1
)
k
=
k
.
softmax
(
dim
=-
1
)
context
=
torch
.
einsum
(
'
bhdn,bhen->bhde
'
,
k
,
v
)
context
=
torch
.
einsum
(
"
bhdn,bhen->bhde
"
,
k
,
v
)
out
=
torch
.
einsum
(
'
bhde,bhdn->bhen
'
,
context
,
q
)
out
=
torch
.
einsum
(
"
bhde,bhdn->bhen
"
,
context
,
q
)
out
=
rearrange
(
out
,
'
b heads c (h w) -> b (heads c) h w
'
,
heads
=
self
.
heads
,
h
=
h
,
w
=
w
)
out
=
rearrange
(
out
,
"
b heads c (h w) -> b (heads c) h w
"
,
heads
=
self
.
heads
,
h
=
h
,
w
=
w
)
return
self
.
to_out
(
out
)
return
self
.
to_out
(
out
)
...
@@ -110,26 +105,10 @@ class SpatialSelfAttention(nn.Module):
...
@@ -110,26 +105,10 @@ class SpatialSelfAttention(nn.Module):
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
norm
=
Normalize
(
in_channels
)
self
.
norm
=
Normalize
(
in_channels
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
in_channels
,
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
kernel_size
=
1
,
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
stride
=
1
,
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
h_
=
x
h_
=
x
...
@@ -139,41 +118,38 @@ class SpatialSelfAttention(nn.Module):
...
@@ -139,41 +118,38 @@ class SpatialSelfAttention(nn.Module):
v
=
self
.
v
(
h_
)
v
=
self
.
v
(
h_
)
# compute attention
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
b
,
c
,
h
,
w
=
q
.
shape
q
=
rearrange
(
q
,
'
b c h w -> b (h w) c
'
)
q
=
rearrange
(
q
,
"
b c h w -> b (h w) c
"
)
k
=
rearrange
(
k
,
'
b c h w -> b c (h w)
'
)
k
=
rearrange
(
k
,
"
b c h w -> b c (h w)
"
)
w_
=
torch
.
einsum
(
'
bij,bjk->bik
'
,
q
,
k
)
w_
=
torch
.
einsum
(
"
bij,bjk->bik
"
,
q
,
k
)
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
# attend to values
# attend to values
v
=
rearrange
(
v
,
'
b c h w -> b c (h w)
'
)
v
=
rearrange
(
v
,
"
b c h w -> b c (h w)
"
)
w_
=
rearrange
(
w_
,
'
b i j -> b j i
'
)
w_
=
rearrange
(
w_
,
"
b i j -> b j i
"
)
h_
=
torch
.
einsum
(
'
bij,bjk->bik
'
,
v
,
w_
)
h_
=
torch
.
einsum
(
"
bij,bjk->bik
"
,
v
,
w_
)
h_
=
rearrange
(
h_
,
'
b c (h w) -> b c h w
'
,
h
=
h
)
h_
=
rearrange
(
h_
,
"
b c (h w) -> b c h w
"
,
h
=
h
)
h_
=
self
.
proj_out
(
h_
)
h_
=
self
.
proj_out
(
h_
)
return
x
+
h_
return
x
+
h_
class
CrossAttention
(
nn
.
Module
):
class
CrossAttention
(
nn
.
Module
):
def
__init__
(
self
,
query_dim
,
context_dim
=
None
,
heads
=
8
,
dim_head
=
64
,
dropout
=
0.
):
def
__init__
(
self
,
query_dim
,
context_dim
=
None
,
heads
=
8
,
dim_head
=
64
,
dropout
=
0.
0
):
super
().
__init__
()
super
().
__init__
()
inner_dim
=
dim_head
*
heads
inner_dim
=
dim_head
*
heads
context_dim
=
default
(
context_dim
,
query_dim
)
context_dim
=
default
(
context_dim
,
query_dim
)
self
.
scale
=
dim_head
**
-
0.5
self
.
scale
=
dim_head
**-
0.5
self
.
heads
=
heads
self
.
heads
=
heads
self
.
to_q
=
nn
.
Linear
(
query_dim
,
inner_dim
,
bias
=
False
)
self
.
to_q
=
nn
.
Linear
(
query_dim
,
inner_dim
,
bias
=
False
)
self
.
to_k
=
nn
.
Linear
(
context_dim
,
inner_dim
,
bias
=
False
)
self
.
to_k
=
nn
.
Linear
(
context_dim
,
inner_dim
,
bias
=
False
)
self
.
to_v
=
nn
.
Linear
(
context_dim
,
inner_dim
,
bias
=
False
)
self
.
to_v
=
nn
.
Linear
(
context_dim
,
inner_dim
,
bias
=
False
)
self
.
to_out
=
nn
.
Sequential
(
self
.
to_out
=
nn
.
Sequential
(
nn
.
Linear
(
inner_dim
,
query_dim
),
nn
.
Dropout
(
dropout
))
nn
.
Linear
(
inner_dim
,
query_dim
),
nn
.
Dropout
(
dropout
)
)
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
):
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
):
h
=
self
.
heads
h
=
self
.
heads
...
@@ -183,31 +159,34 @@ class CrossAttention(nn.Module):
...
@@ -183,31 +159,34 @@ class CrossAttention(nn.Module):
k
=
self
.
to_k
(
context
)
k
=
self
.
to_k
(
context
)
v
=
self
.
to_v
(
context
)
v
=
self
.
to_v
(
context
)
q
,
k
,
v
=
map
(
lambda
t
:
rearrange
(
t
,
'
b n (h d) -> (b h) n d
'
,
h
=
h
),
(
q
,
k
,
v
))
q
,
k
,
v
=
map
(
lambda
t
:
rearrange
(
t
,
"
b n (h d) -> (b h) n d
"
,
h
=
h
),
(
q
,
k
,
v
))
sim
=
torch
.
einsum
(
'
b i d, b j d -> b i j
'
,
q
,
k
)
*
self
.
scale
sim
=
torch
.
einsum
(
"
b i d, b j d -> b i j
"
,
q
,
k
)
*
self
.
scale
if
exists
(
mask
):
if
exists
(
mask
):
mask
=
rearrange
(
mask
,
'
b ... -> b (...)
'
)
mask
=
rearrange
(
mask
,
"
b ... -> b (...)
"
)
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
mask
=
repeat
(
mask
,
'
b j -> (b h) () j
'
,
h
=
h
)
mask
=
repeat
(
mask
,
"
b j -> (b h) () j
"
,
h
=
h
)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
# attention, what we cannot get enough of
# attention, what we cannot get enough of
attn
=
sim
.
softmax
(
dim
=-
1
)
attn
=
sim
.
softmax
(
dim
=-
1
)
out
=
torch
.
einsum
(
'
b i j, b j d -> b i d
'
,
attn
,
v
)
out
=
torch
.
einsum
(
"
b i j, b j d -> b i d
"
,
attn
,
v
)
out
=
rearrange
(
out
,
'
(b h) n d -> b n (h d)
'
,
h
=
h
)
out
=
rearrange
(
out
,
"
(b h) n d -> b n (h d)
"
,
h
=
h
)
return
self
.
to_out
(
out
)
return
self
.
to_out
(
out
)
class
BasicTransformerBlock
(
nn
.
Module
):
class
BasicTransformerBlock
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
n_heads
,
d_head
,
dropout
=
0.
,
context_dim
=
None
,
gated_ff
=
True
,
checkpoint
=
True
):
def
__init__
(
self
,
dim
,
n_heads
,
d_head
,
dropout
=
0.
0
,
context_dim
=
None
,
gated_ff
=
True
,
checkpoint
=
True
):
super
().
__init__
()
super
().
__init__
()
self
.
attn1
=
CrossAttention
(
query_dim
=
dim
,
heads
=
n_heads
,
dim_head
=
d_head
,
dropout
=
dropout
)
# is a self-attention
self
.
attn1
=
CrossAttention
(
query_dim
=
dim
,
heads
=
n_heads
,
dim_head
=
d_head
,
dropout
=
dropout
)
# is a self-attention
self
.
ff
=
FeedForward
(
dim
,
dropout
=
dropout
,
glu
=
gated_ff
)
self
.
ff
=
FeedForward
(
dim
,
dropout
=
dropout
,
glu
=
gated_ff
)
self
.
attn2
=
CrossAttention
(
query_dim
=
dim
,
context_dim
=
context_dim
,
self
.
attn2
=
CrossAttention
(
heads
=
n_heads
,
dim_head
=
d_head
,
dropout
=
dropout
)
# is self-attn if context is none
query_dim
=
dim
,
context_dim
=
context_dim
,
heads
=
n_heads
,
dim_head
=
d_head
,
dropout
=
dropout
)
# is self-attn if context is none
self
.
norm1
=
nn
.
LayerNorm
(
dim
)
self
.
norm1
=
nn
.
LayerNorm
(
dim
)
self
.
norm2
=
nn
.
LayerNorm
(
dim
)
self
.
norm2
=
nn
.
LayerNorm
(
dim
)
self
.
norm3
=
nn
.
LayerNorm
(
dim
)
self
.
norm3
=
nn
.
LayerNorm
(
dim
)
...
@@ -228,29 +207,23 @@ class SpatialTransformer(nn.Module):
...
@@ -228,29 +207,23 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action.
Then apply standard transformer action.
Finally, reshape to image
Finally, reshape to image
"""
"""
def
__init__
(
self
,
in_channels
,
n_heads
,
d_head
,
depth
=
1
,
dropout
=
0.
,
context_dim
=
None
):
def
__init__
(
self
,
in_channels
,
n_heads
,
d_head
,
depth
=
1
,
dropout
=
0.
0
,
context_dim
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
inner_dim
=
n_heads
*
d_head
inner_dim
=
n_heads
*
d_head
self
.
norm
=
Normalize
(
in_channels
)
self
.
norm
=
Normalize
(
in_channels
)
self
.
proj_in
=
nn
.
Conv2d
(
in_channels
,
self
.
proj_in
=
nn
.
Conv2d
(
in_channels
,
inner_dim
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
inner_dim
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
transformer_blocks
=
nn
.
ModuleList
(
self
.
transformer_blocks
=
nn
.
ModuleList
(
[
BasicTransformerBlock
(
inner_dim
,
n_heads
,
d_head
,
dropout
=
dropout
,
context_dim
=
context_dim
)
[
for
d
in
range
(
depth
)]
BasicTransformerBlock
(
inner_dim
,
n_heads
,
d_head
,
dropout
=
dropout
,
context_dim
=
context_dim
)
for
d
in
range
(
depth
)
]
)
)
self
.
proj_out
=
zero_module
(
nn
.
Conv2d
(
inner_dim
,
self
.
proj_out
=
zero_module
(
nn
.
Conv2d
(
inner_dim
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
))
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
))
def
forward
(
self
,
x
,
context
=
None
):
def
forward
(
self
,
x
,
context
=
None
):
# note: if no context is given, cross-attention defaults to self-attention
# note: if no context is given, cross-attention defaults to self-attention
...
@@ -258,13 +231,14 @@ class SpatialTransformer(nn.Module):
...
@@ -258,13 +231,14 @@ class SpatialTransformer(nn.Module):
x_in
=
x
x_in
=
x
x
=
self
.
norm
(
x
)
x
=
self
.
norm
(
x
)
x
=
self
.
proj_in
(
x
)
x
=
self
.
proj_in
(
x
)
x
=
rearrange
(
x
,
'
b c h w -> b (h w) c
'
)
x
=
rearrange
(
x
,
"
b c h w -> b (h w) c
"
)
for
block
in
self
.
transformer_blocks
:
for
block
in
self
.
transformer_blocks
:
x
=
block
(
x
,
context
=
context
)
x
=
block
(
x
,
context
=
context
)
x
=
rearrange
(
x
,
'
b (h w) c -> b c h w
'
,
h
=
h
,
w
=
w
)
x
=
rearrange
(
x
,
"
b (h w) c -> b c h w
"
,
h
=
h
,
w
=
w
)
x
=
self
.
proj_out
(
x
)
x
=
self
.
proj_out
(
x
)
return
x
+
x_in
return
x
+
x_in
def
convert_module_to_f16
(
l
):
def
convert_module_to_f16
(
l
):
"""
"""
Convert primitive modules to float16.
Convert primitive modules to float16.
...
@@ -386,7 +360,7 @@ class AttentionPool2d(nn.Module):
...
@@ -386,7 +360,7 @@ class AttentionPool2d(nn.Module):
output_dim
:
int
=
None
,
output_dim
:
int
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
positional_embedding
=
nn
.
Parameter
(
torch
.
randn
(
embed_dim
,
spacial_dim
**
2
+
1
)
/
embed_dim
**
0.5
)
self
.
positional_embedding
=
nn
.
Parameter
(
torch
.
randn
(
embed_dim
,
spacial_dim
**
2
+
1
)
/
embed_dim
**
0.5
)
self
.
qkv_proj
=
conv_nd
(
1
,
embed_dim
,
3
*
embed_dim
,
1
)
self
.
qkv_proj
=
conv_nd
(
1
,
embed_dim
,
3
*
embed_dim
,
1
)
self
.
c_proj
=
conv_nd
(
1
,
embed_dim
,
output_dim
or
embed_dim
,
1
)
self
.
c_proj
=
conv_nd
(
1
,
embed_dim
,
output_dim
or
embed_dim
,
1
)
self
.
num_heads
=
embed_dim
//
num_heads_channels
self
.
num_heads
=
embed_dim
//
num_heads_channels
...
@@ -453,9 +427,7 @@ class Upsample(nn.Module):
...
@@ -453,9 +427,7 @@ class Upsample(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
assert
x
.
shape
[
1
]
==
self
.
channels
assert
x
.
shape
[
1
]
==
self
.
channels
if
self
.
dims
==
3
:
if
self
.
dims
==
3
:
x
=
F
.
interpolate
(
x
=
F
.
interpolate
(
x
,
(
x
.
shape
[
2
],
x
.
shape
[
3
]
*
2
,
x
.
shape
[
4
]
*
2
),
mode
=
"nearest"
)
x
,
(
x
.
shape
[
2
],
x
.
shape
[
3
]
*
2
,
x
.
shape
[
4
]
*
2
),
mode
=
"nearest"
)
else
:
else
:
x
=
F
.
interpolate
(
x
,
scale_factor
=
2
,
mode
=
"nearest"
)
x
=
F
.
interpolate
(
x
,
scale_factor
=
2
,
mode
=
"nearest"
)
if
self
.
use_conv
:
if
self
.
use_conv
:
...
@@ -472,7 +444,7 @@ class Downsample(nn.Module):
...
@@ -472,7 +444,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
downsampling occurs in the inner-two dimensions.
"""
"""
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
):
def
__init__
(
self
,
channels
,
use_conv
,
dims
=
2
,
out_channels
=
None
,
padding
=
1
):
super
().
__init__
()
super
().
__init__
()
self
.
channels
=
channels
self
.
channels
=
channels
self
.
out_channels
=
out_channels
or
channels
self
.
out_channels
=
out_channels
or
channels
...
@@ -480,9 +452,7 @@ class Downsample(nn.Module):
...
@@ -480,9 +452,7 @@ class Downsample(nn.Module):
self
.
dims
=
dims
self
.
dims
=
dims
stride
=
2
if
dims
!=
3
else
(
1
,
2
,
2
)
stride
=
2
if
dims
!=
3
else
(
1
,
2
,
2
)
if
use_conv
:
if
use_conv
:
self
.
op
=
conv_nd
(
self
.
op
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
)
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
stride
=
stride
,
padding
=
padding
)
else
:
else
:
assert
self
.
channels
==
self
.
out_channels
assert
self
.
channels
==
self
.
out_channels
self
.
op
=
avg_pool_nd
(
dims
,
kernel_size
=
stride
,
stride
=
stride
)
self
.
op
=
avg_pool_nd
(
dims
,
kernel_size
=
stride
,
stride
=
stride
)
...
@@ -558,17 +528,13 @@ class ResBlock(TimestepBlock):
...
@@ -558,17 +528,13 @@ class ResBlock(TimestepBlock):
normalization
(
self
.
out_channels
),
normalization
(
self
.
out_channels
),
nn
.
SiLU
(),
nn
.
SiLU
(),
nn
.
Dropout
(
p
=
dropout
),
nn
.
Dropout
(
p
=
dropout
),
zero_module
(
zero_module
(
conv_nd
(
dims
,
self
.
out_channels
,
self
.
out_channels
,
3
,
padding
=
1
)),
conv_nd
(
dims
,
self
.
out_channels
,
self
.
out_channels
,
3
,
padding
=
1
)
),
)
)
if
self
.
out_channels
==
channels
:
if
self
.
out_channels
==
channels
:
self
.
skip_connection
=
nn
.
Identity
()
self
.
skip_connection
=
nn
.
Identity
()
elif
use_conv
:
elif
use_conv
:
self
.
skip_connection
=
conv_nd
(
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
dims
,
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
else
:
else
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
...
@@ -686,7 +652,7 @@ def count_flops_attn(model, _x, y):
...
@@ -686,7 +652,7 @@ def count_flops_attn(model, _x, y):
# We perform two matmuls with the same number of ops.
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
# the combination of the value vectors.
matmul_ops
=
2
*
b
*
(
num_spatial
**
2
)
*
c
matmul_ops
=
2
*
b
*
(
num_spatial
**
2
)
*
c
model
.
total_ops
+=
torch
.
DoubleTensor
([
matmul_ops
])
model
.
total_ops
+=
torch
.
DoubleTensor
([
matmul_ops
])
...
@@ -710,9 +676,7 @@ class QKVAttentionLegacy(nn.Module):
...
@@ -710,9 +676,7 @@ class QKVAttentionLegacy(nn.Module):
ch
=
width
//
(
3
*
self
.
n_heads
)
ch
=
width
//
(
3
*
self
.
n_heads
)
q
,
k
,
v
=
qkv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
3
,
length
).
split
(
ch
,
dim
=
1
)
q
,
k
,
v
=
qkv
.
reshape
(
bs
*
self
.
n_heads
,
ch
*
3
,
length
).
split
(
ch
,
dim
=
1
)
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
ch
))
weight
=
torch
.
einsum
(
weight
=
torch
.
einsum
(
"bct,bcs->bts"
,
q
*
scale
,
k
*
scale
)
# More stable with f16 than dividing afterwards
"bct,bcs->bts"
,
q
*
scale
,
k
*
scale
)
# More stable with f16 than dividing afterwards
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
weight
=
torch
.
softmax
(
weight
.
float
(),
dim
=-
1
).
type
(
weight
.
dtype
)
a
=
torch
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
)
a
=
torch
.
einsum
(
"bts,bcs->bct"
,
weight
,
v
)
return
a
.
reshape
(
bs
,
-
1
,
length
)
return
a
.
reshape
(
bs
,
-
1
,
length
)
...
@@ -810,19 +774,23 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -810,19 +774,23 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
)
)
if
use_spatial_transformer
:
if
use_spatial_transformer
:
assert
context_dim
is
not
None
,
'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
assert
(
context_dim
is
not
None
),
"Fool!! You forgot to include the dimension of your cross-attention conditioning..."
if
context_dim
is
not
None
:
if
context_dim
is
not
None
:
assert
use_spatial_transformer
,
'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
assert
(
use_spatial_transformer
),
"Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
if
num_heads_upsample
==
-
1
:
if
num_heads_upsample
==
-
1
:
num_heads_upsample
=
num_heads
num_heads_upsample
=
num_heads
if
num_heads
==
-
1
:
if
num_heads
==
-
1
:
assert
num_head_channels
!=
-
1
,
'
Either num_heads or num_head_channels has to be set
'
assert
num_head_channels
!=
-
1
,
"
Either num_heads or num_head_channels has to be set
"
if
num_head_channels
==
-
1
:
if
num_head_channels
==
-
1
:
assert
num_heads
!=
-
1
,
'
Either num_heads or num_head_channels has to be set
'
assert
num_heads
!=
-
1
,
"
Either num_heads or num_head_channels has to be set
"
self
.
image_size
=
image_size
self
.
image_size
=
image_size
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
...
@@ -852,11 +820,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -852,11 +820,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self
.
label_emb
=
nn
.
Embedding
(
num_classes
,
time_embed_dim
)
self
.
label_emb
=
nn
.
Embedding
(
num_classes
,
time_embed_dim
)
self
.
input_blocks
=
nn
.
ModuleList
(
self
.
input_blocks
=
nn
.
ModuleList
(
[
[
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
))]
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
model_channels
,
3
,
padding
=
1
)
)
]
)
)
self
.
_feature_size
=
model_channels
self
.
_feature_size
=
model_channels
input_block_chans
=
[
model_channels
]
input_block_chans
=
[
model_channels
]
...
@@ -883,7 +847,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -883,7 +847,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads
=
ch
//
num_head_channels
num_heads
=
ch
//
num_head_channels
dim_head
=
num_head_channels
dim_head
=
num_head_channels
if
legacy
:
if
legacy
:
#num_heads = 1
#
num_heads = 1
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
layers
.
append
(
layers
.
append
(
AttentionBlock
(
AttentionBlock
(
...
@@ -892,7 +856,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -892,7 +856,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
num_heads
=
num_heads
,
num_heads
=
num_heads
,
num_head_channels
=
dim_head
,
num_head_channels
=
dim_head
,
use_new_attention_order
=
use_new_attention_order
,
use_new_attention_order
=
use_new_attention_order
,
)
if
not
use_spatial_transformer
else
SpatialTransformer
(
)
if
not
use_spatial_transformer
else
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
,
context_dim
=
context_dim
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth
,
context_dim
=
context_dim
)
)
)
)
...
@@ -914,9 +880,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                            down=True,
                        )
                        if resblock_updown
-                        else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
-                        )
+                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
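At the end of each level the model halves the resolution either with a strided `ResBlock` (`down=True`, when `resblock_updown` is set) or with a dedicated `Downsample` layer. A minimal stand-in for the `Downsample` module, assuming the usual strided-convolution behavior rather than this file's exact implementation:

import torch
import torch.nn as nn

class Downsample(nn.Module):
    # Halve the spatial resolution with a strided conv, or average pooling if use_conv is False.
    def __init__(self, channels, use_conv, out_channels=None):
        super().__init__()
        out_channels = out_channels or channels
        if use_conv:
            self.op = nn.Conv2d(channels, out_channels, kernel_size=3, stride=2, padding=1)
        else:
            self.op = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        return self.op(x)

x = torch.randn(1, 320, 64, 64)
print(Downsample(320, use_conv=True)(x).shape)  # torch.Size([1, 320, 32, 32])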
...
@@ -930,7 +894,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
-            #num_heads = 1
+            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
...
@@ -947,9 +911,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(
-                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
-            ),
+            )
+            if not use_spatial_transformer
+            else SpatialTransformer(
+                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+            ),
            ResBlock(
                ch,
                time_embed_dim,
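The middle block is therefore a fixed ResBlock → attention → ResBlock sandwich at the lowest resolution. A toy illustration of that composition, using simplified stand-in modules (none of these classes are the repository's implementations):

import torch
import torch.nn as nn

class ToyResBlock(nn.Module):          # stand-in for ResBlock(ch, time_embed_dim, ...)
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)
    def forward(self, x):
        return x + self.conv(x)

class ToySelfAttention(nn.Module):     # stand-in for AttentionBlock / SpatialTransformer
    def __init__(self, ch, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(ch, num_heads, batch_first=True)
    def forward(self, x):
        b, c, h, w = x.shape
        seq = x.flatten(2).transpose(1, 2)      # (b, h*w, c)
        out, _ = self.attn(seq, seq, seq)
        return x + out.transpose(1, 2).reshape(b, c, h, w)

middle_block = nn.Sequential(ToyResBlock(64), ToySelfAttention(64, 8), ToyResBlock(64))
print(middle_block(torch.randn(2, 64, 8, 8)).shape)  # torch.Size([2, 64, 8, 8])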
...
@@ -984,7 +948,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
-                        #num_heads = 1
+                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
...
@@ -993,7 +957,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
-                        )
+                        )
+                        if not use_spatial_transformer
+                        else SpatialTransformer(
+                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+                        )
                    )
...
@@ -1026,7 +992,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )
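The commented-out `nn.LogSoftmax` matches the intent stated in the comment: the head emits non-normalized logits over the `n_embed` codebook entries and the loss side applies cross-entropy directly. A small hedged sketch of that pattern (shapes and values are illustrative only):

import torch
import torch.nn.functional as F

n_embed, batch, h, w = 8192, 2, 16, 16
logits = torch.randn(batch, n_embed, h, w)              # raw id_predictor output, no LogSoftmax
target_ids = torch.randint(0, n_embed, (batch, h, w))   # ground-truth codebook indices

# F.cross_entropy expects non-normalized logits with the class dimension second,
# so no explicit softmax/log-softmax layer is needed in the model head.
loss = F.cross_entropy(logits, target_ids)
print(loss.item())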
    def convert_to_fp16(self):
...
@@ -1045,7 +1011,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)
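`convert_module_to_f16` / `convert_module_to_f32` are typically tiny helpers that cast the weights of conv and linear layers in place, so only the compute-heavy parameters change precision. A hedged sketch of what such helpers usually look like (not necessarily the exact ones in this repository):

import torch.nn as nn

def convert_module_to_f16(module):
    # Cast conv/linear weights (and biases) to half precision in place.
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
        module.weight.data = module.weight.data.half()
        if module.bias is not None:
            module.bias.data = module.bias.data.half()

def convert_module_to_f32(module):
    # Inverse of the above: back to float32.
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
        module.weight.data = module.weight.data.float()
        if module.bias is not None:
            module.bias.data = module.bias.data.float()

# Usage mirrors the diff: model.middle_block.apply(convert_module_to_f32)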
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
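The signature exposes the three conditioning paths: `timesteps` for the diffusion step, `context` for cross-attention (e.g. text-encoder hidden states), and `y` for class labels. A hedged call sketch, with all shapes and constructor arguments chosen purely for illustration:

import torch

# Illustrative shapes only; the real model is built elsewhere with many more arguments, e.g.
# model = UNetLDMModel(image_size=32, in_channels=4, model_channels=320, out_channels=4,
#                      use_spatial_transformer=True, context_dim=768, ...)
x = torch.randn(2, 4, 32, 32)               # noisy latents, [N x C x H x W]
timesteps = torch.randint(0, 1000, (2,))    # one diffusion step index per sample
context = torch.randn(2, 77, 768)           # cross-attention conditioning, e.g. text embeddings
# eps = model(x, timesteps=timesteps, context=context)   # predicted noise, same shape as x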
...
@@ -1108,7 +1074,7 @@ class EncoderUNetModel(nn.Module):
        use_new_attention_order=False,
        pool="adaptive",
        *args,
-        **kwargs
+        **kwargs,
    ):
        super().__init__()
...
@@ -1137,11 +1103,7 @@ class EncoderUNetModel(nn.Module):
        )
        self.input_blocks = nn.ModuleList(
-            [
-                TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
-                )
-            ]
+            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
...
@@ -1189,9 +1151,7 @@ class EncoderUNetModel(nn.Module):
                            down=True,
                        )
                        if resblock_updown
-                        else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
-                        )
+                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
...
@@ -1239,9 +1199,7 @@ class EncoderUNetModel(nn.Module):
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
-                AttentionPool2d(
-                    (image_size // ds), ch, num_head_channels, out_channels
-                ),
+                AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
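The encoder's head depends on `pool`: `"adaptive"` ends in an attention-based pooling layer over the final feature map, while `"spatial"` flattens spatial statistics into an MLP. A hedged toy comparison of the two ideas, using simplified stand-ins rather than the repository's `AttentionPool2d` or its exact spatial pooling:

import torch
import torch.nn as nn

feats = torch.randn(2, 512, 8, 8)   # final encoder feature map

# "adaptive"-style head: pool spatial positions down to one vector, then project.
adaptive_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, 1000))

# "spatial"-style head: flatten per-position features and classify with a small MLP.
spatial_head = nn.Sequential(nn.Flatten(), nn.Linear(512 * 8 * 8, 2048), nn.SiLU(), nn.Linear(2048, 1000))

print(adaptive_head(feats).shape, spatial_head(feats).shape)  # both torch.Size([2, 1000])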
...
@@ -1296,4 +1254,3 @@ class EncoderUNetModel(nn.Module):
        else:
            h = h.type(x.dtype)
            return self.out(h)
src/diffusers/pipeline_utils.py  View file @ 528b1293
...
@@ -20,10 +20,9 @@ from typing import Optional, Union
from huggingface_hub import snapshot_download
-from .utils import logging, DIFFUSERS_CACHE
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
+from .utils import DIFFUSERS_CACHE, logging
INDEX_FILE = "diffusion_model.pt"
...
src/diffusers/schedulers/gaussian_ddpm.py  View file @ 528b1293
...
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-import math
+import math
+import torch
from torch import nn
from ..configuration_utils import ConfigMixin
-from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
+from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
SAMPLING_CONFIG_NAME = "scheduler_config.json"
...
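`linear_beta_schedule` and `betas_for_alpha_bar` are the two standard ways of building a DDPM noise schedule: a linear ramp of betas, and betas derived from a target cumulative alpha-bar curve (the cosine schedule of Nichol & Dhariwal). A hedged re-implementation of both, with default endpoints taken from the DDPM paper rather than from this file:

import math
import torch

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    # Linearly spaced betas; the endpoint defaults here are assumptions, not read from the repo.
    return torch.linspace(beta_start, beta_end, timesteps)

def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    # Choose betas so that the cumulative product of (1 - beta) follows a cosine alpha_bar curve.
    def alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas)

print(linear_beta_schedule(1000)[:3], betas_for_alpha_bar(1000)[:3])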