Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
77c80489
"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "393aefcdc7c7e786d7b2adf95750cf72fbfbed89"
Commit
77c80489
authored
Jun 13, 2022
by
anton-l
Browse files
Merge remote-tracking branch 'origin/main'
parents
bff9746d
86da45bc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
53 additions
and
2 deletions
+53
-2
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+1
-0
src/diffusers/pipelines/pipeline_bddm.py
src/diffusers/pipelines/pipeline_bddm.py
+45
-0
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+6
-1
No files found.
src/diffusers/__init__.py
View file @
77c80489
...
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
...
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
from
.models.unet_glide
import
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
from
.models.unet_glide
import
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
from
.models.unet_ldm
import
UNetLDMModel
from
.models.unet_ldm
import
UNetLDMModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
DDIM
,
DDPM
,
GLIDE
,
LatentDiffusion
from
.pipelines
import
DDIM
,
DDPM
,
GLIDE
,
LatentDiffusion
,
BDDMPipeline
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
SchedulerMixin
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
SchedulerMixin
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
src/diffusers/pipelines/__init__.py
View file @
77c80489
...
@@ -2,3 +2,4 @@ from .pipeline_ddim import DDIM
...
@@ -2,3 +2,4 @@ from .pipeline_ddim import DDIM
from
.pipeline_ddpm
import
DDPM
from
.pipeline_ddpm
import
DDPM
from
.pipeline_glide
import
GLIDE
from
.pipeline_glide
import
GLIDE
from
.pipeline_latent_diffusion
import
LatentDiffusion
from
.pipeline_latent_diffusion
import
LatentDiffusion
from
.pipeline_bddm
import
BDDMPipeline
src/diffusers/pipelines/pipeline_bddm.py
View file @
77c80489
...
@@ -17,6 +17,9 @@ import numpy as np
...
@@ -17,6 +17,9 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
tqdm
from
..pipeline_utils
import
DiffusionPipeline
def
calc_diffusion_step_embedding
(
diffusion_steps
,
diffusion_step_embed_dim_in
):
def
calc_diffusion_step_embedding
(
diffusion_steps
,
diffusion_step_embed_dim_in
):
...
@@ -234,3 +237,45 @@ class DiffWave(nn.Module):
...
@@ -234,3 +237,45 @@ class DiffWave(nn.Module):
x
=
self
.
init_conv
(
x
).
clone
()
x
=
self
.
init_conv
(
x
).
clone
()
x
=
self
.
residual_layer
((
x
,
mel_spectrogram
,
diffusion_steps
))
x
=
self
.
residual_layer
((
x
,
mel_spectrogram
,
diffusion_steps
))
return
self
.
final_conv
(
x
)
return
self
.
final_conv
(
x
)
class
BDDMPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
diffwave
,
noise_scheduler
):
super
().
__init__
()
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
diffwave
=
diffwave
,
noise_scheduler
=
noise_scheduler
)
@
torch
.
no_grad
()
def
__call__
(
self
,
mel_spectrogram
,
generator
):
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
self
.
diffwave
.
to
(
torch_device
)
audio_length
=
mel_spectrogram
.
size
(
-
1
)
*
self
.
config
.
hop_len
audio_size
=
(
1
,
1
,
audio_length
)
# Sample gaussian noise to begin loop
audio
=
torch
.
normal
(
0
,
1
,
size
=
audio_size
,
generator
=
generator
).
to
(
torch_device
)
timestep_values
=
self
.
noise_scheduler
.
timestep_values
num_prediction_steps
=
len
(
self
.
noise_scheduler
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_prediction_steps
)),
total
=
num_prediction_steps
):
# 1. predict noise residual
with
torch
.
no_grad
():
t
=
(
torch
.
tensor
(
timestep_values
[
t
])
*
torch
.
ones
((
1
,
1
))).
to
(
torch_device
)
residual
=
self
.
diffwave
(
audio
,
mel_spectrogram
,
t
)
# 2. predict previous mean of audio x_t-1
pred_prev_audio
=
self
.
noise_scheduler
.
step
(
residual
,
audio
,
t
)
# 3. optionally sample variance
variance
=
0
if
t
>
0
:
noise
=
torch
.
normal
(
0
,
1
,
size
=
audio_size
,
generator
=
generator
).
to
(
torch_device
)
variance
=
self
.
noise_scheduler
.
get_variance
(
t
).
sqrt
()
*
noise
# 4. set current audio to prev_audio: x_t -> x_t-1
audio
=
pred_prev_audio
+
variance
return
audio
\ No newline at end of file
src/diffusers/schedulers/scheduling_ddim.py
View file @
77c80489
...
@@ -26,6 +26,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -26,6 +26,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_start
=
0.0001
,
beta_start
=
0.0001
,
beta_end
=
0.02
,
beta_end
=
0.02
,
beta_schedule
=
"linear"
,
beta_schedule
=
"linear"
,
trained_betas
=
None
,
timestep_values
=
None
,
clip_predicted_image
=
True
,
clip_predicted_image
=
True
,
tensor_format
=
"np"
,
tensor_format
=
"np"
,
):
):
...
@@ -37,9 +39,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -37,9 +39,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule
=
beta_schedule
,
beta_schedule
=
beta_schedule
,
)
)
self
.
timesteps
=
int
(
timesteps
)
self
.
timesteps
=
int
(
timesteps
)
self
.
timestep_values
=
timestep_values
# save the fixed timestep values for BDDM
self
.
clip_image
=
clip_predicted_image
self
.
clip_image
=
clip_predicted_image
if
beta_schedule
==
"linear"
:
if
trained_betas
is
not
None
:
self
.
betas
=
np
.
asarray
(
trained_betas
)
elif
beta_schedule
==
"linear"
:
self
.
betas
=
linear_beta_schedule
(
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
)
self
.
betas
=
linear_beta_schedule
(
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
)
elif
beta_schedule
==
"squaredcos_cap_v2"
:
elif
beta_schedule
==
"squaredcos_cap_v2"
:
# GLIDE cosine schedule
# GLIDE cosine schedule
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment