Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
528b1293
"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "a6daf822de03b5a911b95253769036da22cbe25a"
Commit
528b1293
authored
Jun 09, 2022
by
anton-l
Browse files
make style
parents
f23bb3e8
cbb19ee8
Changes
23
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
321 additions
and
303 deletions
+321
-303
models/vision/ddim/example.py
models/vision/ddim/example.py
+5
-2
models/vision/ddim/modeling_ddim.py
models/vision/ddim/modeling_ddim.py
+17
-7
models/vision/ddim/run_inference.py
models/vision/ddim/run_inference.py
+4
-2
models/vision/ddpm/example.py
models/vision/ddpm/example.py
+17
-3
models/vision/ddpm/modeling_ddpm.py
models/vision/ddpm/modeling_ddpm.py
+21
-7
models/vision/glide/convert_weights.py
models/vision/glide/convert_weights.py
+17
-7
models/vision/glide/modeling_glide.py
models/vision/glide/modeling_glide.py
+41
-25
models/vision/glide/run_glide.py
models/vision/glide/run_glide.py
+5
-3
setup.py
setup.py
+2
-0
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+19
-24
src/diffusers/dependency_versions_table.py
src/diffusers/dependency_versions_table.py
+5
-19
src/diffusers/dynamic_modules_utils.py
src/diffusers/dynamic_modules_utils.py
+2
-1
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+6
-3
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+2
-2
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+13
-10
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+38
-38
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+101
-144
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+2
-3
src/diffusers/schedulers/gaussian_ddpm.py
src/diffusers/schedulers/gaussian_ddpm.py
+3
-2
No files found.
models/vision/ddim/example.py
View file @
528b1293
#!/usr/bin/env python3
#!/usr/bin/env python3
import
os
import
os
import
pathlib
import
pathlib
from
modeling_ddim
import
DDIM
import
PIL.Image
import
numpy
as
np
import
numpy
as
np
import
PIL.Image
from
modeling_ddim
import
DDIM
model_ids
=
[
"ddim-celeba-hq"
,
"ddim-lsun-church"
,
"ddim-lsun-bedroom"
]
model_ids
=
[
"ddim-celeba-hq"
,
"ddim-lsun-church"
,
"ddim-lsun-bedroom"
]
for
model_id
in
model_ids
:
for
model_id
in
model_ids
:
...
...
models/vision/ddim/modeling_ddim.py
View file @
528b1293
...
@@ -14,13 +14,13 @@
...
@@ -14,13 +14,13 @@
# limitations under the License.
# limitations under the License.
from
diffusers
import
DiffusionPipeline
import
tqdm
import
torch
import
torch
import
tqdm
from
diffusers
import
DiffusionPipeline
class
DDIM
(
DiffusionPipeline
):
class
DDIM
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
,
noise_scheduler
):
def
__init__
(
self
,
unet
,
noise_scheduler
):
super
().
__init__
()
super
().
__init__
()
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
...
@@ -34,12 +34,16 @@ class DDIM(DiffusionPipeline):
...
@@ -34,12 +34,16 @@ class DDIM(DiffusionPipeline):
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
self
.
unet
.
to
(
torch_device
)
self
.
unet
.
to
(
torch_device
)
image
=
self
.
noise_scheduler
.
sample_noise
((
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
resolution
,
self
.
unet
.
resolution
),
device
=
torch_device
,
generator
=
generator
)
image
=
self
.
noise_scheduler
.
sample_noise
(
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
resolution
,
self
.
unet
.
resolution
),
device
=
torch_device
,
generator
=
generator
,
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
# get actual t and t-1
# get actual t and t-1
train_step
=
inference_step_times
[
t
]
train_step
=
inference_step_times
[
t
]
prev_train_step
=
inference_step_times
[
t
-
1
]
if
t
>
0
else
-
1
prev_train_step
=
inference_step_times
[
t
-
1
]
if
t
>
0
else
-
1
# compute alphas
# compute alphas
alpha_prod_t
=
self
.
noise_scheduler
.
get_alpha_prod
(
train_step
)
alpha_prod_t
=
self
.
noise_scheduler
.
get_alpha_prod
(
train_step
)
...
@@ -50,8 +54,14 @@ class DDIM(DiffusionPipeline):
...
@@ -50,8 +54,14 @@ class DDIM(DiffusionPipeline):
beta_prod_t_prev_sqrt
=
(
1
-
alpha_prod_t_prev
).
sqrt
()
beta_prod_t_prev_sqrt
=
(
1
-
alpha_prod_t_prev
).
sqrt
()
# compute relevant coefficients
# compute relevant coefficients
coeff_1
=
(
alpha_prod_t_prev
-
alpha_prod_t
).
sqrt
()
*
alpha_prod_t_prev_rsqrt
*
beta_prod_t_prev_sqrt
/
beta_prod_t_sqrt
*
eta
coeff_1
=
(
coeff_2
=
((
1
-
alpha_prod_t_prev
)
-
coeff_1
**
2
).
sqrt
()
(
alpha_prod_t_prev
-
alpha_prod_t
).
sqrt
()
*
alpha_prod_t_prev_rsqrt
*
beta_prod_t_prev_sqrt
/
beta_prod_t_sqrt
*
eta
)
coeff_2
=
((
1
-
alpha_prod_t_prev
)
-
coeff_1
**
2
).
sqrt
()
# model forward
# model forward
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
...
models/vision/ddim/run_inference.py
View file @
528b1293
#!/usr/bin/env python3
#!/usr/bin/env python3
# !pip install diffusers
# !pip install diffusers
from
modeling_ddim
import
DDIM
import
PIL.Image
import
numpy
as
np
import
numpy
as
np
import
PIL.Image
from
modeling_ddim
import
DDIM
model_id
=
"fusing/ddpm-cifar10"
model_id
=
"fusing/ddpm-cifar10"
model_id
=
"fusing/ddpm-lsun-bedroom"
model_id
=
"fusing/ddpm-lsun-bedroom"
...
...
models/vision/ddpm/example.py
View file @
528b1293
#!/usr/bin/env python3
#!/usr/bin/env python3
import
os
import
os
import
pathlib
import
pathlib
from
modeling_ddpm
import
DDPM
import
PIL.Image
import
numpy
as
np
import
numpy
as
np
model_ids
=
[
"ddpm-lsun-cat"
,
"ddpm-lsun-cat-ema"
,
"ddpm-lsun-church-ema"
,
"ddpm-lsun-church"
,
"ddpm-lsun-bedroom"
,
"ddpm-lsun-bedroom-ema"
,
"ddpm-cifar10-ema"
,
"ddpm-cifar10"
,
"ddpm-celeba-hq"
,
"ddpm-celeba-hq-ema"
]
import
PIL.Image
from
modeling_ddpm
import
DDPM
model_ids
=
[
"ddpm-lsun-cat"
,
"ddpm-lsun-cat-ema"
,
"ddpm-lsun-church-ema"
,
"ddpm-lsun-church"
,
"ddpm-lsun-bedroom"
,
"ddpm-lsun-bedroom-ema"
,
"ddpm-cifar10-ema"
,
"ddpm-cifar10"
,
"ddpm-celeba-hq"
,
"ddpm-celeba-hq-ema"
,
]
for
model_id
in
model_ids
:
for
model_id
in
model_ids
:
path
=
os
.
path
.
join
(
"/home/patrick/images/hf"
,
model_id
)
path
=
os
.
path
.
join
(
"/home/patrick/images/hf"
,
model_id
)
...
...
models/vision/ddpm/modeling_ddpm.py
View file @
528b1293
...
@@ -14,13 +14,13 @@
...
@@ -14,13 +14,13 @@
# limitations under the License.
# limitations under the License.
from
diffusers
import
DiffusionPipeline
import
tqdm
import
torch
import
torch
import
tqdm
from
diffusers
import
DiffusionPipeline
class
DDPM
(
DiffusionPipeline
):
class
DDPM
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
,
noise_scheduler
):
def
__init__
(
self
,
unet
,
noise_scheduler
):
super
().
__init__
()
super
().
__init__
()
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
...
@@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline):
...
@@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline):
self
.
unet
.
to
(
torch_device
)
self
.
unet
.
to
(
torch_device
)
# 1. Sample gaussian noise
# 1. Sample gaussian noise
image
=
self
.
noise_scheduler
.
sample_noise
((
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
resolution
,
self
.
unet
.
resolution
),
device
=
torch_device
,
generator
=
generator
)
image
=
self
.
noise_scheduler
.
sample_noise
(
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
resolution
,
self
.
unet
.
resolution
),
device
=
torch_device
,
generator
=
generator
,
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
len
(
self
.
noise_scheduler
))),
total
=
len
(
self
.
noise_scheduler
)):
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
len
(
self
.
noise_scheduler
))),
total
=
len
(
self
.
noise_scheduler
)):
# i) define coefficients for time step t
# i) define coefficients for time step t
clipped_image_coeff
=
1
/
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
clipped_image_coeff
=
1
/
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
clipped_noise_coeff
=
torch
.
sqrt
(
1
/
self
.
noise_scheduler
.
get_alpha_prod
(
t
)
-
1
)
clipped_noise_coeff
=
torch
.
sqrt
(
1
/
self
.
noise_scheduler
.
get_alpha_prod
(
t
)
-
1
)
image_coeff
=
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
-
1
))
*
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha
(
t
))
/
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
image_coeff
=
(
clipped_coeff
=
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha_prod
(
t
-
1
))
*
self
.
noise_scheduler
.
get_beta
(
t
)
/
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
-
1
))
*
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha
(
t
))
/
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
)
clipped_coeff
=
(
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha_prod
(
t
-
1
))
*
self
.
noise_scheduler
.
get_beta
(
t
)
/
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
)
# ii) predict noise residual
# ii) predict noise residual
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline):
...
@@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline):
prev_image
=
clipped_coeff
*
pred_mean
+
image_coeff
*
image
prev_image
=
clipped_coeff
*
pred_mean
+
image_coeff
*
image
# iv) sample variance
# iv) sample variance
prev_variance
=
self
.
noise_scheduler
.
sample_variance
(
t
,
prev_image
.
shape
,
device
=
torch_device
,
generator
=
generator
)
prev_variance
=
self
.
noise_scheduler
.
sample_variance
(
t
,
prev_image
.
shape
,
device
=
torch_device
,
generator
=
generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image
=
prev_image
+
prev_variance
sampled_prev_image
=
prev_image
+
prev_variance
...
...
models/vision/glide/convert_weights.py
View file @
528b1293
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
diffusers
import
ClassifierFreeGuidanceScheduler
,
GlideDDIMScheduler
,
GLIDETextToImageUNetModel
,
GLIDESuperResUNetModel
from
diffusers
import
(
ClassifierFreeGuidanceScheduler
,
GlideDDIMScheduler
,
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
,
)
from
modeling_glide
import
GLIDE
,
CLIPTextModel
from
modeling_glide
import
GLIDE
,
CLIPTextModel
from
transformers
import
CLIPTextConfig
,
GPT2Tokenizer
from
transformers
import
CLIPTextConfig
,
GPT2Tokenizer
...
@@ -22,7 +27,9 @@ config = CLIPTextConfig(
...
@@ -22,7 +27,9 @@ config = CLIPTextConfig(
use_padding_embeddings
=
True
,
use_padding_embeddings
=
True
,
)
)
model
=
CLIPTextModel
(
config
).
eval
()
model
=
CLIPTextModel
(
config
).
eval
()
tokenizer
=
GPT2Tokenizer
(
"./glide-base/tokenizer/vocab.json"
,
"./glide-base/tokenizer/merges.txt"
,
pad_token
=
"<|endoftext|>"
)
tokenizer
=
GPT2Tokenizer
(
"./glide-base/tokenizer/vocab.json"
,
"./glide-base/tokenizer/merges.txt"
,
pad_token
=
"<|endoftext|>"
)
hf_encoder
=
model
.
text_model
hf_encoder
=
model
.
text_model
...
@@ -97,10 +104,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False)
...
@@ -97,10 +104,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler
=
GlideDDIMScheduler
(
timesteps
=
1000
,
beta_schedule
=
"linear"
)
upscale_scheduler
=
GlideDDIMScheduler
(
timesteps
=
1000
,
beta_schedule
=
"linear"
)
glide
=
GLIDE
(
text_unet
=
text2im_model
,
text_noise_scheduler
=
text_scheduler
,
text_encoder
=
model
,
tokenizer
=
tokenizer
,
glide
=
GLIDE
(
upscale_unet
=
superres_model
,
upscale_noise_scheduler
=
upscale_scheduler
)
text_unet
=
text2im_model
,
text_noise_scheduler
=
text_scheduler
,
text_encoder
=
model
,
tokenizer
=
tokenizer
,
upscale_unet
=
superres_model
,
upscale_noise_scheduler
=
upscale_scheduler
,
)
glide
.
save_pretrained
(
"./glide-base"
)
glide
.
save_pretrained
(
"./glide-base"
)
models/vision/glide/modeling_glide.py
View file @
528b1293
...
@@ -18,10 +18,20 @@ import math
...
@@ -18,10 +18,20 @@ import math
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch.utils.checkpoint
import
torch.utils.checkpoint
from
torch
import
nn
from
torch
import
nn
from
transformers
import
CLIPConfig
,
CLIPModel
,
CLIPTextConfig
,
CLIPVisionConfig
import
tqdm
from
diffusers
import
(
ClassifierFreeGuidanceScheduler
,
DiffusionPipeline
,
GlideDDIMScheduler
,
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
,
)
from
transformers
import
CLIPConfig
,
CLIPModel
,
CLIPTextConfig
,
CLIPVisionConfig
,
GPT2Tokenizer
from
transformers.activations
import
ACT2FN
from
transformers.activations
import
ACT2FN
from
transformers.modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithPooling
from
transformers.modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithPooling
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.modeling_utils
import
PreTrainedModel
...
@@ -34,14 +44,6 @@ from transformers.utils import (
...
@@ -34,14 +44,6 @@ from transformers.utils import (
)
)
import
numpy
as
np
import
torch
import
tqdm
from
diffusers
import
ClassifierFreeGuidanceScheduler
,
GlideDDIMScheduler
,
DiffusionPipeline
,
GLIDETextToImageUNetModel
,
GLIDESuperResUNetModel
from
transformers
import
GPT2Tokenizer
#####################
#####################
# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
#####################
#####################
...
@@ -725,12 +727,16 @@ class GLIDE(DiffusionPipeline):
...
@@ -725,12 +727,16 @@ class GLIDE(DiffusionPipeline):
text_encoder
:
CLIPTextModel
,
text_encoder
:
CLIPTextModel
,
tokenizer
:
GPT2Tokenizer
,
tokenizer
:
GPT2Tokenizer
,
upscale_unet
:
GLIDESuperResUNetModel
,
upscale_unet
:
GLIDESuperResUNetModel
,
upscale_noise_scheduler
:
GlideDDIMScheduler
upscale_noise_scheduler
:
GlideDDIMScheduler
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
register_modules
(
self
.
register_modules
(
text_unet
=
text_unet
,
text_noise_scheduler
=
text_noise_scheduler
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
text_unet
=
text_unet
,
upscale_unet
=
upscale_unet
,
upscale_noise_scheduler
=
upscale_noise_scheduler
text_noise_scheduler
=
text_noise_scheduler
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
upscale_unet
=
upscale_unet
,
upscale_noise_scheduler
=
upscale_noise_scheduler
,
)
)
def
q_posterior_mean_variance
(
self
,
scheduler
,
x_start
,
x_t
,
t
):
def
q_posterior_mean_variance
(
self
,
scheduler
,
x_start
,
x_t
,
t
):
...
@@ -746,9 +752,7 @@ class GLIDE(DiffusionPipeline):
...
@@ -746,9 +752,7 @@ class GLIDE(DiffusionPipeline):
+
_extract_into_tensor
(
scheduler
.
posterior_mean_coef2
,
t
,
x_t
.
shape
)
*
x_t
+
_extract_into_tensor
(
scheduler
.
posterior_mean_coef2
,
t
,
x_t
.
shape
)
*
x_t
)
)
posterior_variance
=
_extract_into_tensor
(
scheduler
.
posterior_variance
,
t
,
x_t
.
shape
)
posterior_variance
=
_extract_into_tensor
(
scheduler
.
posterior_variance
,
t
,
x_t
.
shape
)
posterior_log_variance_clipped
=
_extract_into_tensor
(
posterior_log_variance_clipped
=
_extract_into_tensor
(
scheduler
.
posterior_log_variance_clipped
,
t
,
x_t
.
shape
)
scheduler
.
posterior_log_variance_clipped
,
t
,
x_t
.
shape
)
assert
(
assert
(
posterior_mean
.
shape
[
0
]
posterior_mean
.
shape
[
0
]
==
posterior_variance
.
shape
[
0
]
==
posterior_variance
.
shape
[
0
]
...
@@ -869,19 +873,30 @@ class GLIDE(DiffusionPipeline):
...
@@ -869,19 +873,30 @@ class GLIDE(DiffusionPipeline):
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp
=
0.997
upsample_temp
=
0.997
image
=
self
.
upscale_noise_scheduler
.
sample_noise
(
image
=
(
(
batch_size
,
3
,
256
,
256
),
device
=
torch_device
,
generator
=
generator
self
.
upscale_noise_scheduler
.
sample_noise
(
)
*
upsample_temp
(
batch_size
,
3
,
256
,
256
),
device
=
torch_device
,
generator
=
generator
)
*
upsample_temp
)
num_timesteps
=
len
(
self
.
upscale_noise_scheduler
)
num_timesteps
=
len
(
self
.
upscale_noise_scheduler
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
len
(
self
.
upscale_noise_scheduler
))),
total
=
len
(
self
.
upscale_noise_scheduler
)):
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
len
(
self
.
upscale_noise_scheduler
))),
total
=
len
(
self
.
upscale_noise_scheduler
)
):
# i) define coefficients for time step t
# i) define coefficients for time step t
clipped_image_coeff
=
1
/
torch
.
sqrt
(
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
))
clipped_image_coeff
=
1
/
torch
.
sqrt
(
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
))
clipped_noise_coeff
=
torch
.
sqrt
(
1
/
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
)
-
1
)
clipped_noise_coeff
=
torch
.
sqrt
(
1
/
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
)
-
1
)
image_coeff
=
(
1
-
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
-
1
))
*
torch
.
sqrt
(
image_coeff
=
(
self
.
upscale_noise_scheduler
.
get_alpha
(
t
))
/
(
1
-
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
))
(
1
-
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
-
1
))
clipped_coeff
=
torch
.
sqrt
(
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
-
1
))
*
self
.
upscale_noise_scheduler
.
get_beta
(
*
torch
.
sqrt
(
self
.
upscale_noise_scheduler
.
get_alpha
(
t
))
t
)
/
(
1
-
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
))
/
(
1
-
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
))
)
clipped_coeff
=
(
torch
.
sqrt
(
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
-
1
))
*
self
.
upscale_noise_scheduler
.
get_beta
(
t
)
/
(
1
-
self
.
upscale_noise_scheduler
.
get_alpha_prod
(
t
))
)
# ii) predict noise residual
# ii) predict noise residual
time_input
=
torch
.
tensor
([
t
]
*
image
.
shape
[
0
],
device
=
torch_device
)
time_input
=
torch
.
tensor
([
t
]
*
image
.
shape
[
0
],
device
=
torch_device
)
...
@@ -895,8 +910,9 @@ class GLIDE(DiffusionPipeline):
...
@@ -895,8 +910,9 @@ class GLIDE(DiffusionPipeline):
prev_image
=
clipped_coeff
*
pred_mean
+
image_coeff
*
image
prev_image
=
clipped_coeff
*
pred_mean
+
image_coeff
*
image
# iv) sample variance
# iv) sample variance
prev_variance
=
self
.
upscale_noise_scheduler
.
sample_variance
(
t
,
prev_image
.
shape
,
device
=
torch_device
,
prev_variance
=
self
.
upscale_noise_scheduler
.
sample_variance
(
generator
=
generator
)
t
,
prev_image
.
shape
,
device
=
torch_device
,
generator
=
generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image
=
prev_image
+
prev_variance
sampled_prev_image
=
prev_image
+
prev_variance
...
...
models/vision/glide/run_glide.py
View file @
528b1293
import
torch
import
torch
from
diffusers
import
DiffusionPipeline
import
PIL.Image
import
PIL.Image
from
diffusers
import
DiffusionPipeline
generator
=
torch
.
Generator
()
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
0
)
generator
=
generator
.
manual_seed
(
0
)
...
@@ -15,8 +17,8 @@ img = pipeline("a crayon drawing of a corgi", generator)
...
@@ -15,8 +17,8 @@ img = pipeline("a crayon drawing of a corgi", generator)
# process image to PIL
# process image to PIL
img
=
img
.
squeeze
(
0
)
img
=
img
.
squeeze
(
0
)
img
=
((
img
+
1
)
*
127.5
).
round
().
clamp
(
0
,
255
).
to
(
torch
.
uint8
).
cpu
().
numpy
()
img
=
((
img
+
1
)
*
127.5
).
round
().
clamp
(
0
,
255
).
to
(
torch
.
uint8
).
cpu
().
numpy
()
image_pil
=
PIL
.
Image
.
fromarray
(
img
)
image_pil
=
PIL
.
Image
.
fromarray
(
img
)
# save image
# save image
image_pil
.
save
(
"test.png"
)
image_pil
.
save
(
"test.png"
)
\ No newline at end of file
setup.py
View file @
528b1293
...
@@ -84,6 +84,7 @@ _deps = [
...
@@ -84,6 +84,7 @@ _deps = [
"isort>=5.5.4"
,
"isort>=5.5.4"
,
"numpy"
,
"numpy"
,
"pytest"
,
"pytest"
,
"regex!=2019.12.17"
,
"requests"
,
"requests"
,
"torch>=1.4"
,
"torch>=1.4"
,
"torchvision"
,
"torchvision"
,
...
@@ -168,6 +169,7 @@ install_requires = [
...
@@ -168,6 +169,7 @@ install_requires = [
deps
[
"filelock"
],
deps
[
"filelock"
],
deps
[
"huggingface-hub"
],
deps
[
"huggingface-hub"
],
deps
[
"numpy"
],
deps
[
"numpy"
],
deps
[
"regex"
],
deps
[
"requests"
],
deps
[
"requests"
],
deps
[
"torch"
],
deps
[
"torch"
],
deps
[
"torchvision"
],
deps
[
"torchvision"
],
...
...
src/diffusers/__init__.py
View file @
528b1293
...
@@ -6,7 +6,7 @@ __version__ = "0.0.1"
...
@@ -6,7 +6,7 @@ __version__ = "0.0.1"
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
from
.models.unet
import
UNetModel
from
.models.unet
import
UNetModel
from
.models.unet_glide
import
GLIDE
TextToImage
UNetModel
,
GLIDE
SuperRes
UNetModel
from
.models.unet_glide
import
GLIDE
SuperRes
UNetModel
,
GLIDE
TextToImage
UNetModel
from
.models.unet_ldm
import
UNetLDMModel
from
.models.unet_ldm
import
UNetLDMModel
from
.models.vqvae
import
VQModel
from
.models.vqvae
import
VQModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
...
...
src/diffusers/configuration_utils.py
View file @
528b1293
...
@@ -23,13 +23,13 @@ import os
...
@@ -23,13 +23,13 @@ import os
import
re
import
re
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
requests
import
HTTPError
from
huggingface_hub
import
hf_hub_download
from
huggingface_hub
import
hf_hub_download
from
requests
import
HTTPError
from
.
import
__version__
from
.utils
import
(
from
.utils
import
(
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
DIFFUSERS_CACHE
,
DIFFUSERS_CACHE
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
EntryNotFoundError
,
EntryNotFoundError
,
RepositoryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
,
RevisionNotFoundError
,
...
@@ -37,9 +37,6 @@ from .utils import (
...
@@ -37,9 +37,6 @@ from .utils import (
)
)
from
.
import
__version__
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
_re_configuration_file
=
re
.
compile
(
r
"config\.(.*)\.json"
)
_re_configuration_file
=
re
.
compile
(
r
"config\.(.*)\.json"
)
...
@@ -95,9 +92,7 @@ class ConfigMixin:
...
@@ -95,9 +92,7 @@ class ConfigMixin:
@
classmethod
@
classmethod
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
config_dict
=
cls
.
get_config_dict
(
config_dict
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
...
@@ -157,16 +152,16 @@ class ConfigMixin:
...
@@ -157,16 +152,16 @@ class ConfigMixin:
except
RepositoryNotFoundError
:
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier listed
on
"
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier listed"
"'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a token
having
"
"
on
'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a token"
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and
pass
"
"
having
permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
"`use_auth_token=True`."
"
pass
`use_auth_token=True`."
)
)
except
RevisionNotFoundError
:
except
RevisionNotFoundError
:
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for
this
"
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for"
f
"model name. Check the model page at
'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for
"
"
this
model name. Check the model page at"
"
available revisions."
f
" 'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for
available revisions."
)
)
except
EntryNotFoundError
:
except
EntryNotFoundError
:
raise
EnvironmentError
(
raise
EnvironmentError
(
...
@@ -174,14 +169,16 @@ class ConfigMixin:
...
@@ -174,14 +169,16 @@ class ConfigMixin:
)
)
except
HTTPError
as
err
:
except
HTTPError
as
err
:
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
"There was a specific connection error when trying to load"
f
"
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
)
except
ValueError
:
except
ValueError
:
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it in"
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it"
f
" the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a directory"
f
" in the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a"
f
" containing a
{
cls
.
config_name
}
file.
\n
Checkout your internet connection or see how to run the"
f
" directory containing a
{
cls
.
config_name
}
file.
\n
Checkout your internet connection or see how to"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
" run the library in offline mode at"
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
)
except
EnvironmentError
:
except
EnvironmentError
:
raise
EnvironmentError
(
raise
EnvironmentError
(
...
@@ -195,9 +192,7 @@ class ConfigMixin:
...
@@ -195,9 +192,7 @@ class ConfigMixin:
# Load config dict
# Load config dict
config_dict
=
cls
.
_dict_from_json_file
(
config_file
)
config_dict
=
cls
.
_dict_from_json_file
(
config_file
)
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"It looks like the config file at '
{
config_file
}
' is not a valid JSON file."
)
f
"It looks like the config file at '
{
config_file
}
' is not a valid JSON file."
)
return
config_dict
return
config_dict
...
...
src/diffusers/dependency_versions_table.py
View file @
528b1293
...
@@ -3,29 +3,15 @@
...
@@ -3,29 +3,15 @@
# 2. run `make deps_table_update``
# 2. run `make deps_table_update``
deps
=
{
deps
=
{
"Pillow"
:
"Pillow"
,
"Pillow"
:
"Pillow"
,
"accelerate"
:
"accelerate>=0.9.0"
,
"black"
:
"black~=22.0,>=22.3"
,
"black"
:
"black~=22.0,>=22.3"
,
"codecarbon"
:
"codecarbon==1.2.0"
,
"filelock"
:
"filelock"
,
"dataclasses"
:
"dataclasses"
,
"flake8"
:
"flake8>=3.8.3"
,
"datasets"
:
"datasets"
,
"huggingface-hub"
:
"huggingface-hub"
,
"GitPython"
:
"GitPython<3.1.19"
,
"hf-doc-builder"
:
"hf-doc-builder>=0.3.0"
,
"huggingface-hub"
:
"huggingface-hub>=0.1.0,<1.0"
,
"importlib_metadata"
:
"importlib_metadata"
,
"isort"
:
"isort>=5.5.4"
,
"isort"
:
"isort>=5.5.4"
,
"numpy"
:
"numpy
>=1.17
"
,
"numpy"
:
"numpy"
,
"pytest"
:
"pytest"
,
"pytest"
:
"pytest"
,
"pytest-timeout"
:
"pytest-timeout"
,
"pytest-xdist"
:
"pytest-xdist"
,
"python"
:
"python>=3.7.0"
,
"regex"
:
"regex!=2019.12.17"
,
"regex"
:
"regex!=2019.12.17"
,
"requests"
:
"requests"
,
"requests"
:
"requests"
,
"sagemaker"
:
"sagemaker>=2.31.0"
,
"tokenizers"
:
"tokenizers>=0.11.1,!=0.11.3,<0.13"
,
"torch"
:
"torch>=1.4"
,
"torch"
:
"torch>=1.4"
,
"torchaudio"
:
"torchaudio"
,
"torchvision"
:
"torchvision"
,
"tqdm"
:
"tqdm>=4.27"
,
"unidic"
:
"unidic>=1.0.2"
,
"unidic_lite"
:
"unidic_lite>=1.0.7"
,
"uvicorn"
:
"uvicorn"
,
}
}
src/diffusers/dynamic_modules_utils.py
View file @
528b1293
...
@@ -23,7 +23,8 @@ from pathlib import Path
...
@@ -23,7 +23,8 @@ from pathlib import Path
from
typing
import
Dict
,
Optional
,
Union
from
typing
import
Dict
,
Optional
,
Union
from
huggingface_hub
import
cached_download
from
huggingface_hub
import
cached_download
from
.utils
import
HF_MODULES_CACHE
,
DIFFUSERS_DYNAMIC_MODULE_NAME
,
logging
from
.utils
import
DIFFUSERS_DYNAMIC_MODULE_NAME
,
HF_MODULES_CACHE
,
logging
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
src/diffusers/modeling_utils.py
View file @
528b1293
...
@@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union
...
@@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union
import
torch
import
torch
from
torch
import
Tensor
,
device
from
torch
import
Tensor
,
device
from
requests
import
HTTPError
from
huggingface_hub
import
hf_hub_download
from
huggingface_hub
import
hf_hub_download
from
requests
import
HTTPError
from
.utils
import
(
from
.utils
import
(
CONFIG_NAME
,
CONFIG_NAME
,
...
@@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module):
...
@@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module):
f
"'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for available revisions."
f
"'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for available revisions."
)
)
except
EntryNotFoundError
:
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
model_file
}
."
)
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
model_file
}
."
)
except
HTTPError
as
err
:
except
HTTPError
as
err
:
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
"There was a specific connection error when trying to load"
f
"
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
)
except
ValueError
:
except
ValueError
:
raise
EnvironmentError
(
raise
EnvironmentError
(
...
...
src/diffusers/models/__init__.py
View file @
528b1293
...
@@ -17,6 +17,6 @@
...
@@ -17,6 +17,6 @@
# limitations under the License.
# limitations under the License.
from
.unet
import
UNetModel
from
.unet
import
UNetModel
from
.unet_glide
import
GLIDE
TextToImage
UNetModel
,
GLIDE
SuperRes
UNetModel
from
.unet_glide
import
GLIDE
SuperRes
UNetModel
,
GLIDE
TextToImage
UNetModel
from
.unet_ldm
import
UNetLDMModel
from
.unet_ldm
import
UNetLDMModel
from
.vqvae
import
VQModel
from
.vqvae
import
VQModel
\ No newline at end of file
src/diffusers/models/unet.py
View file @
528b1293
...
@@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast
...
@@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast
from
torch.optim
import
Adam
from
torch.optim
import
Adam
from
torch.utils
import
data
from
torch.utils
import
data
from
torchvision
import
transforms
,
utils
from
PIL
import
Image
from
PIL
import
Image
from
torchvision
import
transforms
,
utils
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
...
@@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin):
# dataset classes
# dataset classes
class
Dataset
(
data
.
Dataset
):
class
Dataset
(
data
.
Dataset
):
def
__init__
(
self
,
folder
,
image_size
,
exts
=
[
'
jpg
'
,
'
jpeg
'
,
'
png
'
]):
def
__init__
(
self
,
folder
,
image_size
,
exts
=
[
"
jpg
"
,
"
jpeg
"
,
"
png
"
]):
super
().
__init__
()
super
().
__init__
()
self
.
folder
=
folder
self
.
folder
=
folder
self
.
image_size
=
image_size
self
.
image_size
=
image_size
self
.
paths
=
[
p
for
ext
in
exts
for
p
in
Path
(
f
'
{
folder
}
'
).
glob
(
f
'
**/*.
{
ext
}
'
)]
self
.
paths
=
[
p
for
ext
in
exts
for
p
in
Path
(
f
"
{
folder
}
"
).
glob
(
f
"
**/*.
{
ext
}
"
)]
self
.
transform
=
transforms
.
Compose
([
self
.
transform
=
transforms
.
Compose
(
transforms
.
Resize
(
image_size
),
[
transforms
.
RandomHorizontalFlip
(),
transforms
.
Resize
(
image_size
),
transforms
.
CenterCrop
(
image_size
),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToTensor
()
transforms
.
CenterCrop
(
image_size
),
])
transforms
.
ToTensor
(),
]
)
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
paths
)
return
len
(
self
.
paths
)
...
@@ -359,7 +362,7 @@ class Dataset(data.Dataset):
...
@@ -359,7 +362,7 @@ class Dataset(data.Dataset):
# trainer class
# trainer class
class
EMA
()
:
class
EMA
:
def
__init__
(
self
,
beta
):
def
__init__
(
self
,
beta
):
super
().
__init__
()
super
().
__init__
()
self
.
beta
=
beta
self
.
beta
=
beta
...
...
src/diffusers/models/unet_glide.py
View file @
528b1293
...
@@ -647,24 +647,24 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
...
@@ -647,24 +647,24 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
in_channels
=
3
,
in_channels
=
3
,
model_channels
=
192
,
model_channels
=
192
,
out_channels
=
6
,
out_channels
=
6
,
num_res_blocks
=
3
,
num_res_blocks
=
3
,
attention_resolutions
=
(
2
,
4
,
8
),
attention_resolutions
=
(
2
,
4
,
8
),
dropout
=
0
,
dropout
=
0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
channel_mult
=
(
1
,
2
,
4
,
8
),
conv_resample
=
True
,
conv_resample
=
True
,
dims
=
2
,
dims
=
2
,
use_checkpoint
=
False
,
use_checkpoint
=
False
,
use_fp16
=
False
,
use_fp16
=
False
,
num_heads
=
1
,
num_heads
=
1
,
num_head_channels
=-
1
,
num_head_channels
=-
1
,
num_heads_upsample
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
resblock_updown
=
False
,
transformer_dim
=
512
transformer_dim
=
512
,
):
):
super
().
__init__
(
super
().
__init__
(
in_channels
=
in_channels
,
in_channels
=
in_channels
,
...
@@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
...
@@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample
=
num_heads_upsample
,
num_heads_upsample
=
num_heads_upsample
,
use_scale_shift_norm
=
use_scale_shift_norm
,
use_scale_shift_norm
=
use_scale_shift_norm
,
resblock_updown
=
resblock_updown
,
resblock_updown
=
resblock_updown
,
transformer_dim
=
transformer_dim
transformer_dim
=
transformer_dim
,
)
)
self
.
register
(
self
.
register
(
in_channels
=
in_channels
,
in_channels
=
in_channels
,
...
@@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
...
@@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample
=
num_heads_upsample
,
num_heads_upsample
=
num_heads_upsample
,
use_scale_shift_norm
=
use_scale_shift_norm
,
use_scale_shift_norm
=
use_scale_shift_norm
,
resblock_updown
=
resblock_updown
,
resblock_updown
=
resblock_updown
,
transformer_dim
=
transformer_dim
transformer_dim
=
transformer_dim
,
)
)
self
.
transformer_proj
=
nn
.
Linear
(
transformer_dim
,
self
.
model_channels
*
4
)
self
.
transformer_proj
=
nn
.
Linear
(
transformer_dim
,
self
.
model_channels
*
4
)
...
@@ -737,23 +737,23 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
...
@@ -737,23 +737,23 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
in_channels
=
3
,
in_channels
=
3
,
model_channels
=
192
,
model_channels
=
192
,
out_channels
=
6
,
out_channels
=
6
,
num_res_blocks
=
3
,
num_res_blocks
=
3
,
attention_resolutions
=
(
2
,
4
,
8
),
attention_resolutions
=
(
2
,
4
,
8
),
dropout
=
0
,
dropout
=
0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
channel_mult
=
(
1
,
2
,
4
,
8
),
conv_resample
=
True
,
conv_resample
=
True
,
dims
=
2
,
dims
=
2
,
use_checkpoint
=
False
,
use_checkpoint
=
False
,
use_fp16
=
False
,
use_fp16
=
False
,
num_heads
=
1
,
num_heads
=
1
,
num_head_channels
=-
1
,
num_head_channels
=-
1
,
num_heads_upsample
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
resblock_updown
=
False
,
):
):
super
().
__init__
(
super
().
__init__
(
in_channels
=
in_channels
,
in_channels
=
in_channels
,
...
@@ -809,4 +809,4 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
...
@@ -809,4 +809,4 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
module
(
h
,
emb
)
h
=
module
(
h
,
emb
)
return
self
.
out
(
h
)
return
self
.
out
(
h
)
\ No newline at end of file
src/diffusers/models/unet_ldm.py
View file @
528b1293
This diff is collapsed.
Click to expand it.
src/diffusers/pipeline_utils.py
View file @
528b1293
...
@@ -20,10 +20,9 @@ from typing import Optional, Union
...
@@ -20,10 +20,9 @@ from typing import Optional, Union
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
from
.utils
import
logging
,
DIFFUSERS_CACHE
from
.configuration_utils
import
ConfigMixin
from
.configuration_utils
import
ConfigMixin
from
.dynamic_modules_utils
import
get_class_from_dynamic_module
from
.dynamic_modules_utils
import
get_class_from_dynamic_module
from
.utils
import
DIFFUSERS_CACHE
,
logging
INDEX_FILE
=
"diffusion_model.pt"
INDEX_FILE
=
"diffusion_model.pt"
...
@@ -105,7 +104,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -105,7 +104,7 @@ class DiffusionPipeline(ConfigMixin):
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Optional
[
Union
[
str
,
os
.
PathLike
]],
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Optional
[
Union
[
str
,
os
.
PathLike
]],
**
kwargs
):
r
"""
r
"""
Add docstrings
Add docstrings
"""
"""
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
...
...
src/diffusers/schedulers/gaussian_ddpm.py
View file @
528b1293
...
@@ -11,12 +11,13 @@
...
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
torch
import
math
import
math
import
torch
from
torch
import
nn
from
torch
import
nn
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
.schedulers_utils
import
linear_beta_schedule
,
betas_for_alpha_bar
from
.schedulers_utils
import
betas_for_alpha_bar
,
linear_beta_schedule
SAMPLING_CONFIG_NAME
=
"scheduler_config.json"
SAMPLING_CONFIG_NAME
=
"scheduler_config.json"
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment