renzhc / diffusers_dcu

Commit 17c574a1
authored Jun 15, 2022 by Patrick von Platen
parent f84bbd35

remove torchvision dependency

Showing 16 changed files with 174 additions and 299 deletions (+174 −299)
examples/train_ddpm.py                          +5   −3
setup.py                                        +0   −2
src/diffusers/__init__.py                       +4   −4
src/diffusers/configuration_utils.py            +1   −1
src/diffusers/dependency_versions_table.py      +0   −1
src/diffusers/models/__init__.py                +2   −2
src/diffusers/models/unet.py                    +0   −169
src/diffusers/models/unet_grad_tts.py           +41  −41
src/diffusers/pipeline_utils.py                 +5   −5
src/diffusers/pipelines/__init__.py             +2   −2
src/diffusers/pipelines/pipeline_bddm.py        +1   −1
src/diffusers/pipelines/pipeline_glide.py       +1   −3
src/diffusers/pipelines/pipeline_grad_tts.py    +92  −61
src/diffusers/schedulers/scheduling_ddpm.py     +1   −1
src/diffusers/schedulers/scheduling_pndm.py     +7   −2
tests/test_modeling_utils.py                    +12  −1
examples/train_ddpm.py

@@ -144,9 +144,11 @@ if __name__ == "__main__":
         type=str,
         default="no",
         choices=["no", "fp16", "bf16"],
-        help="Whether to use mixed precision. Choose"
-        "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-        "and an Nvidia Ampere GPU.",
+        help=(
+            "Whether to use mixed precision. Choose"
+            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+            "and an Nvidia Ampere GPU."
+        ),
     )
     args = parser.parse_args()
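A note while reading this hunk: Python joins adjacent string literals with no separator, so the help text renders as "...Choosebetween fp16..." both before and after this change; the parentheses added here only group the literals for the formatter. A minimal standalone sketch of the behavior (not code from this repo):

# Adjacent string literals concatenate at compile time, with no implicit
# whitespace inserted between them.
text = (
    "Whether to use mixed precision. Choose"
    "between fp16 and bf16 (bfloat16)."
)
print(text)
# Whether to use mixed precision. Choosebetween fp16 and bf16 (bfloat16).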
setup.py

@@ -87,7 +87,6 @@ _deps = [
     "regex!=2019.12.17",
     "requests",
     "torch>=1.4",
-    "torchvision",
 ]
 # this is a lookup table with items like:
@@ -172,7 +171,6 @@ install_requires = [
     deps["regex"],
     deps["requests"],
     deps["torch"],
-    deps["torchvision"],
     deps["Pillow"],
 ]
src/diffusers/__init__.py

@@ -6,10 +6,10 @@ __version__ = "0.0.3"
 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
-from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
-from .models.unet_ldm import UNetLDMModel
 from .models.unet_grad_tts import UNetGradTTSModel
+from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
+from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, LatentDiffusion
-from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
+from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
src/diffusers/configuration_utils.py
src/diffusers/dependency_versions_table.py

@@ -13,5 +13,4 @@ deps = {
     "regex": "regex!=2019.12.17",
     "requests": "requests",
     "torch": "torch>=1.4",
-    "torchvision": "torchvision",
 }
src/diffusers/models/__init__.py

@@ -17,6 +17,6 @@
 # limitations under the License.
 from .unet import UNetModel
-from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
-from .unet_ldm import UNetLDMModel
 from .unet_grad_tts import UNetGradTTSModel
+from .unet_ldm import UNetLDMModel
src/diffusers/models/unet.py

@@ -26,7 +26,6 @@ from torch.optim import Adam
 from torch.utils import data
 from PIL import Image
-from torchvision import transforms, utils
 from tqdm import tqdm
 from ..configuration_utils import ConfigMixin
@@ -331,171 +330,3 @@ class UNetModel(ModelMixin, ConfigMixin):
         h = nonlinearity(h)
         h = self.conv_out(h)
         return h
-
-
-# dataset classes
-class Dataset(data.Dataset):
-    def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
-        super().__init__()
-        self.folder = folder
-        self.image_size = image_size
-        self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]
-
-        self.transform = transforms.Compose(
-            [
-                transforms.Resize(image_size),
-                transforms.RandomHorizontalFlip(),
-                transforms.CenterCrop(image_size),
-                transforms.ToTensor(),
-            ]
-        )
-
-    def __len__(self):
-        return len(self.paths)
-
-    def __getitem__(self, index):
-        path = self.paths[index]
-        img = Image.open(path)
-        return self.transform(img)
-
-
-# trainer class
-class EMA:
-    def __init__(self, beta):
-        super().__init__()
-        self.beta = beta
-
-    def update_model_average(self, ma_model, current_model):
-        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
-            old_weight, up_weight = ma_params.data, current_params.data
-            ma_params.data = self.update_average(old_weight, up_weight)
-
-    def update_average(self, old, new):
-        if old is None:
-            return new
-        return old * self.beta + (1 - self.beta) * new
-
-
-def cycle(dl):
-    while True:
-        for data_dl in dl:
-            yield data_dl
-
-
-def num_to_groups(num, divisor):
-    groups = num // divisor
-    remainder = num % divisor
-    arr = [divisor] * groups
-    if remainder > 0:
-        arr.append(remainder)
-    return arr
-
-
-class Trainer(object):
-    def __init__(
-        self,
-        diffusion_model,
-        folder,
-        *,
-        ema_decay=0.995,
-        image_size=128,
-        train_batch_size=32,
-        train_lr=1e-4,
-        train_num_steps=100000,
-        gradient_accumulate_every=2,
-        amp=False,
-        step_start_ema=2000,
-        update_ema_every=10,
-        save_and_sample_every=1000,
-        results_folder="./results",
-    ):
-        super().__init__()
-        self.model = diffusion_model
-        self.ema = EMA(ema_decay)
-        self.ema_model = copy.deepcopy(self.model)
-        self.update_ema_every = update_ema_every
-
-        self.step_start_ema = step_start_ema
-        self.save_and_sample_every = save_and_sample_every
-
-        self.batch_size = train_batch_size
-        self.image_size = diffusion_model.image_size
-        self.gradient_accumulate_every = gradient_accumulate_every
-        self.train_num_steps = train_num_steps
-
-        self.ds = Dataset(folder, image_size)
-        self.dl = cycle(data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=True, pin_memory=True))
-        self.opt = Adam(diffusion_model.parameters(), lr=train_lr)
-
-        self.step = 0
-
-        self.amp = amp
-        self.scaler = GradScaler(enabled=amp)
-
-        self.results_folder = Path(results_folder)
-        self.results_folder.mkdir(exist_ok=True)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        self.ema_model.load_state_dict(self.model.state_dict())
-
-    def step_ema(self):
-        if self.step < self.step_start_ema:
-            self.reset_parameters()
-            return
-        self.ema.update_model_average(self.ema_model, self.model)
-
-    def save(self, milestone):
-        data = {
-            "step": self.step,
-            "model": self.model.state_dict(),
-            "ema": self.ema_model.state_dict(),
-            "scaler": self.scaler.state_dict(),
-        }
-        torch.save(data, str(self.results_folder / f"model-{milestone}.pt"))
-
-    def load(self, milestone):
-        data = torch.load(str(self.results_folder / f"model-{milestone}.pt"))
-
-        self.step = data["step"]
-        self.model.load_state_dict(data["model"])
-        self.ema_model.load_state_dict(data["ema"])
-        self.scaler.load_state_dict(data["scaler"])
-
-    def train(self):
-        with tqdm(initial=self.step, total=self.train_num_steps) as pbar:
-            while self.step < self.train_num_steps:
-                for i in range(self.gradient_accumulate_every):
-                    data = next(self.dl).cuda()
-
-                    with autocast(enabled=self.amp):
-                        loss = self.model(data)
-                        self.scaler.scale(loss / self.gradient_accumulate_every).backward()
-
-                    pbar.set_description(f"loss: {loss.item():.4f}")
-
-                self.scaler.step(self.opt)
-                self.scaler.update()
-                self.opt.zero_grad()
-
-                if self.step % self.update_ema_every == 0:
-                    self.step_ema()
-
-                if self.step != 0 and self.step % self.save_and_sample_every == 0:
-                    self.ema_model.eval()
-
-                    milestone = self.step // self.save_and_sample_every
-                    batches = num_to_groups(36, self.batch_size)
-                    all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
-                    all_images = torch.cat(all_images_list, dim=0)
-                    utils.save_image(all_images, str(self.results_folder / f"sample-{milestone}.png"), nrow=6)
-                    self.save(milestone)
-
-                self.step += 1
-                pbar.update(1)
-
-        print("training complete")
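The deleted Trainer kept an exponential moving average of the model weights via EMA.update_average (old * beta + (1 - beta) * new), seeding the shadow copy from the live model in reset_parameters. A minimal self-contained sketch of the same update, using a hypothetical toy model rather than the deleted classes:

import copy

import torch

def ema_update(ema_model, model, beta=0.995):
    # Same rule as the deleted EMA.update_average: shadow weights decay
    # toward the live weights at rate (1 - beta).
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.data = p_ema.data * beta + (1 - beta) * p.data

model = torch.nn.Linear(4, 4)
ema_model = copy.deepcopy(model)  # like Trainer.reset_parameters()
ema_update(ema_model, model)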
src/diffusers/models/unet_grad_tts.py

@@ -2,6 +2,7 @@ import math
 import torch

+
 try:
     from einops import rearrange, repeat
 except:
@@ -11,6 +12,7 @@ except:
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin

+
 class Mish(torch.nn.Module):
     def forward(self, x):
         return x * torch.tanh(torch.nn.functional.softplus(x))
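Mish, as defined in this file, is x * tanh(softplus(x)). PyTorch has shipped the same activation as torch.nn.functional.mish since 1.9, so a quick consistency check (not part of this commit) could look like:

import torch

x = torch.randn(8)
manual = x * torch.tanh(torch.nn.functional.softplus(x))  # definition kept in this file
builtin = torch.nn.functional.mish(x)                     # built-in since PyTorch 1.9
assert torch.allclose(manual, builtin)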
@@ -47,9 +49,9 @@ class Rezero(torch.nn.Module):
 class Block(torch.nn.Module):
     def __init__(self, dim, dim_out, groups=8):
         super(Block, self).__init__()
-        self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
-                                         padding=1), torch.nn.GroupNorm(
-                                         groups, dim_out), Mish())
+        self.block = torch.nn.Sequential(
+            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
+        )

     def forward(self, x, mask):
         output = self.block(x * mask)
@@ -59,8 +61,7 @@ class Block(torch.nn.Module):
 class ResnetBlock(torch.nn.Module):
     def __init__(self, dim, dim_out, time_emb_dim, groups=8):
         super(ResnetBlock, self).__init__()
-        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
-                                                               dim_out))
+        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
         self.block1 = Block(dim, dim_out, groups=groups)
         self.block2 = Block(dim_out, dim_out, groups=groups)
@@ -88,13 +89,11 @@ class LinearAttention(torch.nn.Module):
     def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
-                            heads=self.heads, qkv=3)
+        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
         k = k.softmax(dim=-1)
-        context = torch.einsum('bhdn,bhen->bhde', k, v)
-        out = torch.einsum('bhde,bhdn->bhen', context, q)
-        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
-                        heads=self.heads, h=h, w=w)
+        context = torch.einsum("bhdn,bhen->bhde", k, v)
+        out = torch.einsum("bhde,bhdn->bhen", context, q)
+        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
         return self.to_out(out)
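The reflowed forward above is standard linear attention: keys are softmax-normalized over spatial positions, folded with the values into a small per-head context, and the queries then read from that context, keeping the cost linear in the number of positions. A standalone sketch of the two einsums with made-up shapes:

import torch

b, heads, c, n = 2, 4, 8, 16  # batch, heads, channels per head, positions (h * w)
q = torch.randn(b, heads, c, n)
k = torch.randn(b, heads, c, n)
v = torch.randn(b, heads, c, n)

k = k.softmax(dim=-1)                             # normalize each key channel over positions
context = torch.einsum("bhdn,bhen->bhde", k, v)   # (c x c) summary per head, O(n) in positions
out = torch.einsum("bhde,bhdn->bhen", context, q)
print(out.shape)  # torch.Size([2, 4, 8, 16])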
@@ -124,16 +123,7 @@ class SinusoidalPosEmb(torch.nn.Module):
 class UNetGradTTSModel(ModelMixin, ConfigMixin):
-    def __init__(self,
-                 dim,
-                 dim_mults=(1, 2, 4),
-                 groups=8,
-                 n_spks=None,
-                 spk_emb_dim=64,
-                 n_feats=80,
-                 pe_scale=1000):
+    def __init__(
+        self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000
+    ):
         super(UNetGradTTSModel, self).__init__()
         self.register(
@@ -143,7 +133,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             n_spks=n_spks,
             spk_emb_dim=spk_emb_dim,
             n_feats=n_feats,
-            pe_scale=pe_scale
+            pe_scale=pe_scale,
         )
         self.dim = dim
@@ -154,11 +144,11 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         self.pe_scale = pe_scale

         if n_spks > 1:
-            self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
-                                               torch.nn.Linear(spk_emb_dim * 4, n_feats))
+            self.spk_mlp = torch.nn.Sequential(
+                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
+                torch.nn.Linear(spk_emb_dim * 4, n_feats)
+            )
         self.time_pos_emb = SinusoidalPosEmb(dim)
-        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
-                                       torch.nn.Linear(dim * 4, dim))
+        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
         dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
@@ -168,11 +158,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         for ind, (dim_in, dim_out) in enumerate(in_out):
             is_last = ind >= (num_resolutions - 1)
-            self.downs.append(torch.nn.ModuleList([
-                ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
-                ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
-                Residual(Rezero(LinearAttention(dim_out))),
-                Downsample(dim_out) if not is_last else torch.nn.Identity()]))
+            self.downs.append(
+                torch.nn.ModuleList(
+                    [
+                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
+                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
+                        Residual(Rezero(LinearAttention(dim_out))),
+                        Downsample(dim_out) if not is_last else torch.nn.Identity(),
+                    ]
+                )
+            )

         mid_dim = dims[-1]
         self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
@@ -180,11 +175,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)

         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
-            self.ups.append(torch.nn.ModuleList([
-                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
-                ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
-                Residual(Rezero(LinearAttention(dim_in))),
-                Upsample(dim_in)]))
+            self.ups.append(
+                torch.nn.ModuleList(
+                    [
+                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
+                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
+                        Residual(Rezero(LinearAttention(dim_in))),
+                        Upsample(dim_in),
+                    ]
+                )
+            )

         self.final_block = Block(dim, dim)
         self.final_conv = torch.nn.Conv2d(dim, 1, 1)
src/diffusers/pipeline_utils.py

src/diffusers/pipelines/__init__.py

+from .pipeline_bddm import BDDM
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
-from .pipeline_pndm import PNDM
 from .pipeline_glide import GLIDE
 from .pipeline_latent_diffusion import LatentDiffusion
-from .pipeline_bddm import BDDM
+from .pipeline_pndm import PNDM
src/diffusers/pipelines/pipeline_bddm.py

src/diffusers/pipelines/pipeline_glide.py

@@ -832,9 +832,7 @@ class GLIDE(DiffusionPipeline):
         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = torch.randn(
-            (batch_size, self.text_unet.in_channels, 64, 64), generator=generator
-        ).to(torch_device)
+        image = torch.randn((batch_size, self.text_unet.in_channels, 64, 64), generator=generator).to(torch_device)

         # 2. Encode tokens
         # an empty input is needed to guide the model away from it
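The batch of two noise samples (real prompt plus empty prompt) is the usual classifier-free guidance setup: the UNet predicts noise for both, and the two predictions are blended. This commit only reflows the randn call; for orientation, the conventional combination looks like the sketch below (guidance_scale is a hypothetical name here, not taken from this diff):

import torch

def guided_noise(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: move the prediction away from the
    # unconditional estimate, in the direction of the conditional one.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)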
src/diffusers/pipelines/pipeline_grad_tts.py

@@ -39,14 +39,13 @@ def generate_path(duration, mask):
     cum_duration_flat = cum_duration.view(b * t_x)
     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
     path = path.view(b, t_x, t_y)
-    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0],
-                                          [1, 0], [0, 0]]))[:, :-1]
+    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
     path = path * mask
     return path


 def duration_loss(logw, logw_, lengths):
     loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
     return loss
@@ -62,7 +61,7 @@ class LayerNorm(nn.Module):
     def forward(self, x):
         n_dims = len(x.shape)
         mean = torch.mean(x, 1, keepdim=True)
-        variance = torch.mean((x - mean)**2, 1, keepdim=True)
+        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

         x = (x - mean) * torch.rsqrt(variance + self.eps)
@@ -72,8 +71,7 @@ class LayerNorm(nn.Module):
 class ConvReluNorm(nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
-                 n_layers, p_dropout):
+    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
         super(ConvReluNorm, self).__init__()
         self.in_channels = in_channels
         self.hidden_channels = hidden_channels
@@ -84,13 +82,13 @@ class ConvReluNorm(nn.Module):
         self.conv_layers = torch.nn.ModuleList()
         self.norm_layers = torch.nn.ModuleList()
-        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
-                                                kernel_size, padding=kernel_size // 2))
+        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
         self.norm_layers.append(LayerNorm(hidden_channels))
         self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
         for _ in range(n_layers - 1):
-            self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
-                                                    kernel_size, padding=kernel_size // 2))
+            self.conv_layers.append(
+                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
+            )
             self.norm_layers.append(LayerNorm(hidden_channels))
         self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
         self.proj.weight.data.zero_()
@@ -114,11 +112,9 @@ class DurationPredictor(nn.Module):
         self.p_dropout = p_dropout

         self.drop = torch.nn.Dropout(p_dropout)
-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
-                                      kernel_size, padding=kernel_size // 2)
+        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
         self.norm_1 = LayerNorm(filter_channels)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
-                                      kernel_size, padding=kernel_size // 2)
+        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
         self.norm_2 = LayerNorm(filter_channels)
         self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
@@ -136,9 +132,17 @@ class DurationPredictor(nn.Module):
 class MultiHeadAttention(nn.Module):
-    def __init__(self, channels, out_channels, n_heads, window_size=None,
-                 heads_share=True, p_dropout=0.0, proximal_bias=False,
-                 proximal_init=False):
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        n_heads,
+        window_size=None,
+        heads_share=True,
+        p_dropout=0.0,
+        proximal_bias=False,
+        proximal_init=False,
+    ):
         super(MultiHeadAttention, self).__init__()
         assert channels % n_heads == 0
@@ -158,10 +162,12 @@ class MultiHeadAttention(nn.Module):
         if window_size is not None:
             n_heads_rel = 1 if heads_share else n_heads
             rel_stddev = self.k_channels**-0.5
-            self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
-                                                window_size * 2 + 1, self.k_channels) * rel_stddev)
-            self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
-                                                window_size * 2 + 1, self.k_channels) * rel_stddev)
+            self.emb_rel_k = torch.nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
+            )
+            self.emb_rel_v = torch.nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
+            )
         self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
         self.drop = torch.nn.Dropout(p_dropout)
@@ -198,8 +204,7 @@ class MultiHeadAttention(nn.Module):
             scores = scores + scores_local
         if self.proximal_bias:
             assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
-                                                                    dtype=scores.dtype)
+            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
         if mask is not None:
             scores = scores.masked_fill(mask == 0, -1e4)
         p_attn = torch.nn.functional.softmax(scores, dim=-1)
@@ -208,8 +213,7 @@ class MultiHeadAttention(nn.Module):
         if self.window_size is not None:
             relative_weights = self._absolute_position_to_relative_position(p_attn)
             value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
-            output = output + self._matmul_with_relative_values(relative_weights,
-                                                                value_relative_embeddings)
+            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
         output = output.transpose(2, 3).contiguous().view(b, d, t_t)
         return output, p_attn
@@ -227,28 +231,27 @@ class MultiHeadAttention(nn.Module):
         slice_end_position = slice_start_position + 2 * length - 1
         if pad_length > 0:
-            padded_relative_embeddings = torch.nn.functional.pad(relative_embeddings,
-                                                                 convert_pad_shape([[0, 0],
-                                                                 [pad_length, pad_length], [0, 0]]))
+            padded_relative_embeddings = torch.nn.functional.pad(
+                relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])
+            )
         else:
             padded_relative_embeddings = relative_embeddings
         used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
         return used_relative_embeddings

     def _relative_position_to_absolute_position(self, x):
         batch, heads, length, _ = x.size()
-        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0],[0, 0],[0, 0],[0, 1]]))
+        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
         x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0],[0, 0],[0, length - 1]]))
+        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
         return x_final

     def _absolute_position_to_relative_position(self, x):
         batch, heads, length, _ = x.size()
         x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
         x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-        x_final = x_flat.view([batch, heads, length, 2 * length])[:,:,:,1:]
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
         return x_final

     def _attention_bias_proximal(self, length):
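_relative_position_to_absolute_position, reformatted above, is the usual pad-and-reshape skew trick: pad one column, flatten, pad length - 1 entries, then reshape to (length + 1, 2 * length - 1) so each row's relative offsets line up as absolute key positions. A shape-only sketch of the same steps:

import torch

batch, heads, length = 1, 1, 4
x = torch.randn(batch, heads, length, 2 * length - 1)  # scores indexed by relative offset

x = torch.nn.functional.pad(x, (0, 1))                 # last dim -> 2 * length
x_flat = x.view(batch, heads, length * 2 * length)
x_flat = torch.nn.functional.pad(x_flat, (0, length - 1))
x_final = x_flat.view(batch, heads, length + 1, 2 * length - 1)[:, :, :length, length - 1 :]
print(x_final.shape)  # torch.Size([1, 1, 4, 4]): scores indexed by (query, key) position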
@@ -258,8 +261,7 @@ class MultiHeadAttention(nn.Module):
 class FFN(nn.Module):
-    def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
-                 p_dropout=0.0):
+    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
         super(FFN, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -267,10 +269,8 @@ class FFN(nn.Module):
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout

-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
-                                      padding=kernel_size // 2)
+        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
         self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
         self.drop = torch.nn.Dropout(p_dropout)

     def forward(self, x, x_mask):
@@ -282,8 +282,17 @@ class FFN(nn.Module):
 class Encoder(nn.Module):
-    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
-                 kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        window_size=None,
+        **kwargs,
+    ):
         super(Encoder, self).__init__()
         self.hidden_channels = hidden_channels
         self.filter_channels = filter_channels
@@ -299,11 +308,15 @@ class Encoder(nn.Module):
         self.ffn_layers = torch.nn.ModuleList()
         self.norm_layers_2 = torch.nn.ModuleList()
         for _ in range(self.n_layers):
-            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
-                                                       n_heads, window_size=window_size, p_dropout=p_dropout))
+            self.attn_layers.append(
+                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout)
+            )
             self.norm_layers_1.append(LayerNorm(hidden_channels))
-            self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
-                                       filter_channels, kernel_size, p_dropout=p_dropout))
+            self.ffn_layers.append(
+                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
+            )
             self.norm_layers_2.append(LayerNorm(hidden_channels))

     def forward(self, x, x_mask):
@@ -321,9 +334,21 @@ class Encoder(nn.Module):
 class TextEncoder(ModelMixin, ConfigMixin):
-    def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
-                 filter_channels_dp, n_heads, n_layers, kernel_size,
-                 p_dropout, window_size=None, spk_emb_dim=64, n_spks=1):
+    def __init__(
+        self,
+        n_vocab,
+        n_feats,
+        n_channels,
+        filter_channels,
+        filter_channels_dp,
+        n_heads,
+        n_layers,
+        kernel_size,
+        p_dropout,
+        window_size=None,
+        spk_emb_dim=64,
+        n_spks=1,
+    ):
         super(TextEncoder, self).__init__()
         self.register(
@@ -338,10 +363,9 @@ class TextEncoder(ModelMixin, ConfigMixin):
             p_dropout=p_dropout,
             window_size=window_size,
             spk_emb_dim=spk_emb_dim,
-            n_spks=n_spks
+            n_spks=n_spks,
         )
         self.n_vocab = n_vocab
         self.n_feats = n_feats
         self.n_channels = n_channels
@@ -358,15 +382,22 @@ class TextEncoder(ModelMixin, ConfigMixin):
         self.emb = torch.nn.Embedding(n_vocab, n_channels)
         torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
-        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
-                                   kernel_size=5, n_layers=3, p_dropout=0.5)
+        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, kernel_size=5, n_layers=3, p_dropout=0.5)
-        self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
-                               kernel_size, p_dropout, window_size=window_size)
+        self.encoder = Encoder(
+            n_channels + (spk_emb_dim if n_spks > 1 else 0),
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+            window_size=window_size,
+        )
         self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
-        self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
-                                        kernel_size, p_dropout)
+        self.proj_w = DurationPredictor(
+            n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
+        )

     def forward(self, x, x_lengths, spk=None):
         x = self.emb(x) * math.sqrt(self.n_channels)
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_pndm.py

@@ -84,7 +84,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
-        warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order)
+        warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
+            np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order
+        )
         self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
         return self.warmup_time_steps[num_inference_steps]
@@ -137,7 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         at = alphas_cump[t + 1].view(-1, 1, 1, 1)
         at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
-        x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et)
+        x_delta = (at_next - at) * (
+            (1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
+            - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
+        )
         x_next = x + x_delta
         return x_next
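Concretely, for timesteps = 1000, num_inference_steps = 50 and pndm_order = 4 (illustrative values, not taken from this diff), the reflowed warmup expression evaluates to:

import numpy as np

timesteps, num_inference_steps, pndm_order = 1000, 50, 4
inference_step_times = list(range(0, timesteps, timesteps // num_inference_steps))
# last pndm_order step times, each repeated twice, offset by 0 and half a stride
warmup_time_steps = np.array(inference_step_times[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, timesteps // num_inference_steps // 2]), pndm_order
)
print(warmup_time_steps)  # [920 930 940 950 960 970 980 990]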
tests/test_modeling_utils.py

@@ -19,7 +19,18 @@ import unittest
 import torch
-from diffusers import DDIM, DDPM, PNDM, GLIDE, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
+from diffusers import (
+    BDDM,
+    DDIM,
+    DDPM,
+    GLIDE,
+    PNDM,
+    DDIMScheduler,
+    DDPMScheduler,
+    LatentDiffusion,
+    PNDMScheduler,
+    UNetModel,
+)
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.pipeline_bddm import DiffWave