Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
1cf7933e
Commit
1cf7933e
authored
Jun 27, 2022
by
anton-l
Browse files
Framework-agnostic timestep broadcasting
parent
0e13d329
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
11 deletions
+38
-11
examples/train_unconditional.py
examples/train_unconditional.py
+6
-3
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+4
-8
src/diffusers/schedulers/scheduling_utils.py
src/diffusers/schedulers/scheduling_utils.py
+28
-0
No files found.
examples/train_unconditional.py
View file @
1cf7933e
...
@@ -7,7 +7,7 @@ import torch.nn.functional as F
...
@@ -7,7 +7,7 @@ import torch.nn.functional as F
import
PIL.Image
import
PIL.Image
from
accelerate
import
Accelerator
from
accelerate
import
Accelerator
from
datasets
import
load_dataset
from
datasets
import
load_dataset
from
diffusers
import
DDPM
,
DDPMScheduler
,
UNetModel
from
diffusers
import
DDPM
Pipeline
,
DDPMScheduler
,
UNetModel
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
EMAModel
from
diffusers.training_utils
import
EMAModel
...
@@ -71,7 +71,7 @@ def main(args):
...
@@ -71,7 +71,7 @@ def main(args):
model
,
optimizer
,
train_dataloader
,
lr_scheduler
model
,
optimizer
,
train_dataloader
,
lr_scheduler
)
)
ema_model
=
EMAModel
(
model
,
inv_gamma
=
1.0
,
power
=
3
/
4
)
ema_model
=
EMAModel
(
model
,
inv_gamma
=
args
.
ema_inv_gamma
,
power
=
args
.
ema_power
,
max_value
=
args
.
ema_max_decay
)
if
args
.
push_to_hub
:
if
args
.
push_to_hub
:
repo
=
init_git_repo
(
args
,
at_init
=
True
)
repo
=
init_git_repo
(
args
,
at_init
=
True
)
...
@@ -133,7 +133,7 @@ def main(args):
...
@@ -133,7 +133,7 @@ def main(args):
# Generate a sample image for visual inspection
# Generate a sample image for visual inspection
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
pipeline
=
DDPM
(
pipeline
=
DDPM
Pipeline
(
unet
=
accelerator
.
unwrap_model
(
ema_model
.
averaged_model
),
noise_scheduler
=
noise_scheduler
unet
=
accelerator
.
unwrap_model
(
ema_model
.
averaged_model
),
noise_scheduler
=
noise_scheduler
)
)
...
@@ -172,6 +172,9 @@ if __name__ == "__main__":
...
@@ -172,6 +172,9 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
1e-4
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
1e-4
)
parser
.
add_argument
(
"--warmup_steps"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--warmup_steps"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--ema_inv_gamma"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--ema_power"
,
type
=
float
,
default
=
3
/
4
)
parser
.
add_argument
(
"--ema_max_decay"
,
type
=
float
,
default
=
0.999
)
parser
.
add_argument
(
"--push_to_hub"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--push_to_hub"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--hub_token"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--hub_token"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--hub_model_id"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--hub_model_id"
,
type
=
str
,
default
=
None
)
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
1cf7933e
...
@@ -144,16 +144,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -144,16 +144,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
pred_prev_sample
return
pred_prev_sample
def
training_step
(
self
,
original_samples
:
torch
.
Tensor
,
noise
:
torch
.
Tensor
,
timesteps
:
torch
.
Tensor
):
def
training_step
(
self
,
original_samples
:
torch
.
Tensor
,
noise
:
torch
.
Tensor
,
timesteps
:
torch
.
Tensor
):
if
timesteps
.
dim
()
!=
1
:
raise
ValueError
(
"`timesteps` must be a 1D tensor"
)
device
=
original_samples
.
device
batch_size
=
original_samples
.
shape
[
0
]
timesteps
=
timesteps
.
reshape
(
batch_size
,
1
,
1
,
1
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
self
.
match_shape
(
sqrt_alpha_prod
,
original_samples
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
noisy_samples
=
sqrt_alpha_prod
.
to
(
device
)
*
original_samples
+
sqrt_one_minus_alpha_prod
.
to
(
device
)
*
noise
sqrt_one_minus_alpha_prod
=
self
.
match_shape
(
sqrt_one_minus_alpha_prod
,
original_samples
)
noisy_samples
=
sqrt_alpha_prod
*
original_samples
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_samples
return
noisy_samples
def
__len__
(
self
):
def
__len__
(
self
):
...
...
src/diffusers/schedulers/scheduling_utils.py
View file @
1cf7933e
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
typing
import
Union
SCHEDULER_CONFIG_NAME
=
"scheduler_config.json"
SCHEDULER_CONFIG_NAME
=
"scheduler_config.json"
...
@@ -50,3 +52,29 @@ class SchedulerMixin:
...
@@ -50,3 +52,29 @@ class SchedulerMixin:
return
torch
.
log
(
tensor
)
return
torch
.
log
(
tensor
)
raise
ValueError
(
f
"`self.tensor_format`:
{
self
.
tensor_format
}
is not valid."
)
raise
ValueError
(
f
"`self.tensor_format`:
{
self
.
tensor_format
}
is not valid."
)
def
match_shape
(
self
,
values
:
Union
[
np
.
ndarray
,
torch
.
Tensor
],
broadcast_array
:
Union
[
np
.
ndarray
,
torch
.
Tensor
]
):
"""
Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
Args:
timesteps: an array or tensor of values to extract.
broadcast_array: an array with a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
Returns:
a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
tensor_format
=
getattr
(
self
,
"tensor_format"
,
"pt"
)
values
=
values
.
flatten
()
while
len
(
values
.
shape
)
<
len
(
broadcast_array
.
shape
):
values
=
values
[...,
None
]
if
tensor_format
==
"pt"
:
values
=
values
.
to
(
broadcast_array
.
device
)
return
values
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment