renzhc / diffusers_dcu · Commits · 3f9e3d8a

Commit 3f9e3d8a, authored Jun 27, 2022 by anton-l
Parent: c31736a4

    add EMA during training
Showing 3 changed files with 106 additions and 13 deletions:

    examples/train_unconditional.py    +17 -12
    src/diffusers/training_utils.py    +88 -0
    tests/test_modeling_utils.py       +1 -1
examples/train_unconditional.py  (view file @ 3f9e3d8a)
@@ -9,14 +9,14 @@ from accelerate import Accelerator
 from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
-from diffusers.modeling_utils import unwrap_model
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from diffusers.utils import logging
 from torchvision.transforms import (
     CenterCrop,
     Compose,
     InterpolationMode,
-    Lambda,
+    Normalize,
     RandomHorizontalFlip,
     Resize,
     ToTensor,
@@ -48,7 +48,7 @@ def main(args):
             CenterCrop(args.resolution),
             RandomHorizontalFlip(),
             ToTensor(),
-            Lambda(lambda x: x * 2 - 1),
+            Normalize([0.5], [0.5]),
         ]
     )
     dataset = load_dataset(args.dataset, split="train")
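Aside (illustrative only, not part of the commit): the augmentation change above is behavior-preserving. ToTensor() produces values in [0, 1], and both Lambda(lambda x: x * 2 - 1) and Normalize([0.5], [0.5]) rescale that range to [-1, 1], since (x - 0.5) / 0.5 = 2x - 1. A quick sanity check:

    import torch
    from torchvision.transforms import Lambda, Normalize

    x = torch.rand(3, 8, 8)                 # ToTensor() output lives in [0, 1]
    a = Lambda(lambda t: t * 2 - 1)(x)
    b = Normalize([0.5], [0.5])(x)          # (x - 0.5) / 0.5 == x * 2 - 1
    print(torch.allclose(a, b))             # True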
@@ -71,6 +71,8 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )

+    ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
+
     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)
@@ -87,6 +89,7 @@ def main(args):
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f"  Total optimization steps = {max_steps}")

+    global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
         with tqdm(total=len(train_dataloader), unit="ba") as pbar:
@@ -117,19 +120,22 @@ def main(args):
                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
+                ema_model.step(model, global_step)
                 optimizer.zero_grad()
             pbar.update(1)
-            pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
-        optimizer.step()
+            pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay)
+            global_step += 1

-        if is_distributed:
-            torch.distributed.barrier()
+        accelerator.wait_for_everyone()

         # Generate a sample image for visual inspection
-        if args.local_rank in [-1, 0]:
+        if accelerator.is_main_process:
+            model.eval()
             with torch.no_grad():
-                pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
+                pipeline = DDPM(unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler)
                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
@@ -151,8 +157,7 @@ def main(args):
                     push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
                 else:
                     pipeline.save_pretrained(args.output_dir)
-        if is_distributed:
-            torch.distributed.barrier()
+        accelerator.wait_for_everyone()


 if __name__ == "__main__":
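Taken together, the script changes follow the standard EMA recipe: build an EMAModel shadow copy of the network after accelerator.prepare(), call ema_model.step(model, global_step) after every optimizer step, and sample or checkpoint from ema_model.averaged_model rather than the live weights. A minimal, self-contained sketch of that pattern (the Linear model, SGD optimizer, and dummy loss below are placeholders, and it assumes a diffusers install that includes this revision):

    import torch
    from diffusers.training_utils import EMAModel

    model = torch.nn.Linear(4, 4)                      # stand-in for the UNet
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)

    global_step = 0
    for _ in range(100):                               # dummy training loop
        loss = model(torch.randn(2, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        ema_model.step(model, global_step)             # update the shadow weights
        optimizer.zero_grad()
        global_step += 1

    # evaluation / saving should use the averaged copy, not the live model
    ema_model.averaged_model.eval()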
src/diffusers/training_utils.py  (new file, 0 → 100644, view file @ 3f9e3d8a)
import copy

import torch


class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        model,
        update_after_step=0,
        inv_gamma=1.0,
        power=2 / 3,
        min_value=0.0,
        max_value=0.9999,
        device=None,
    ):
        """
        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
            good values for models you plan to train for a million or more steps (reaches decay
            factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
            you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
            215.4k steps).
        Args:
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
        """
        self.averaged_model = copy.deepcopy(model)
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        if device is not None:
            self.averaged_model = self.averaged_model.to(device=device)

        self.decay = 0.0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)
        value = 1 - (1 + step / self.inv_gamma) ** -self.power

        if step <= 0:
            return 0.0

        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model, optimization_step):
        ema_state_dict = {}
        ema_params = self.averaged_model.state_dict()

        self.decay = self.get_decay(optimization_step)

        for key, param in new_model.named_parameters():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                ema_params[key] = ema_param

            if not param.requires_grad:
                ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
                ema_param = ema_params[key]
            else:
                ema_param.mul_(self.decay)
                ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)

            ema_state_dict[key] = ema_param

        for key, param in new_model.named_buffers():
            ema_state_dict[key] = param

        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
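For reference, get_decay implements the warmup schedule decay(step) = 1 - (1 + step / inv_gamma) ** (-power), clamped to [min_value, max_value]. A small, hypothetical standalone check of the numbers quoted in the docstring (illustrative only, not part of the commit):

    def decay(step, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999):
        # mirrors EMAModel.get_decay, ignoring the update_after_step offset
        value = 1 - (1 + step / inv_gamma) ** -power
        return max(min_value, min(value, max_value)) if step > 0 else 0.0

    print(decay(31_600))                 # ~0.999  with the default power=2/3
    print(decay(1_000_000))              # ~0.9999 (hits max_value)
    print(decay(10_000, power=3 / 4))    # ~0.999  with power=3/4, as used in train_unconditional.py
    print(decay(215_400, power=3 / 4))   # ~0.9999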
tests/test_modeling_utils.py  (view file @ 3f9e3d8a)
@@ -25,10 +25,10 @@ from diffusers import (
     BDDM,
     DDIM,
     DDPM,
-    Glide,
     PNDM,
     DDIMScheduler,
     DDPMScheduler,
+    Glide,
     GlideSuperResUNetModel,
     GlideTextToImageUNetModel,
     GradTTS,