Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
9c82c32b
Commit
9c82c32b
authored
Jun 21, 2022
by
anton-l
Browse files
make style
parent
1a099e5e
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
40 additions
and
35 deletions
+40
-35
examples/train_unconditional.py
examples/train_unconditional.py
+3
-3
src/diffusers/hub_utils.py
src/diffusers/hub_utils.py
+25
-21
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+1
-1
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+1
-1
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+2
-2
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+2
-2
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+6
-5
No files found.
examples/train_unconditional.py
View file @
9c82c32b
...
@@ -8,6 +8,9 @@ import PIL.Image
...
@@ -8,6 +8,9 @@ import PIL.Image
from
accelerate
import
Accelerator
from
accelerate
import
Accelerator
from
datasets
import
load_dataset
from
datasets
import
load_dataset
from
diffusers
import
DDPM
,
DDPMScheduler
,
UNetModel
from
diffusers
import
DDPM
,
DDPMScheduler
,
UNetModel
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.modeling_utils
import
unwrap_model
from
diffusers.utils
import
logging
from
torchvision.transforms
import
(
from
torchvision.transforms
import
(
CenterCrop
,
CenterCrop
,
Compose
,
Compose
,
...
@@ -19,10 +22,7 @@ from torchvision.transforms import (
...
@@ -19,10 +22,7 @@ from torchvision.transforms import (
)
)
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
transformers
import
get_linear_schedule_with_warmup
from
transformers
import
get_linear_schedule_with_warmup
from
diffusers.modeling_utils
import
unwrap_model
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
...
...
src/diffusers/hub_utils.py
View file @
9c82c32b
from
typing
import
Optional
from
.utils
import
logging
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
import
yaml
import
os
import
os
from
pathlib
import
Path
import
shutil
import
shutil
from
pathlib
import
Path
from
typing
import
Optional
import
yaml
from
diffusers
import
DiffusionPipeline
from
diffusers
import
DiffusionPipeline
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
...
@@ -68,17 +70,21 @@ def init_git_repo(args, at_init: bool = False):
...
@@ -68,17 +70,21 @@ def init_git_repo(args, at_init: bool = False):
repo
.
git_pull
()
repo
.
git_pull
()
# By default, ignore the checkpoint folders
# By default, ignore the checkpoint folders
if
(
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
))
and
args
.
hub_strategy
!=
"all_checkpoints"
:
not
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
))
and
args
.
hub_strategy
!=
"all_checkpoints"
):
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w"
,
encoding
=
"utf-8"
)
as
writer
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
writelines
([
"checkpoint-*/"
])
writer
.
writelines
([
"checkpoint-*/"
])
return
repo
return
repo
def
push_to_hub
(
args
,
pipeline
:
DiffusionPipeline
,
repo
:
Repository
,
commit_message
:
Optional
[
str
]
=
"End of training"
,
blocking
:
bool
=
True
,
**
kwargs
)
->
str
:
def
push_to_hub
(
args
,
pipeline
:
DiffusionPipeline
,
repo
:
Repository
,
commit_message
:
Optional
[
str
]
=
"End of training"
,
blocking
:
bool
=
True
,
**
kwargs
,
)
->
str
:
"""
"""
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
Parameters:
Parameters:
...
@@ -108,18 +114,19 @@ def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_mess
...
@@ -108,18 +114,19 @@ def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_mess
return
return
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
if
blocking
and
len
(
repo
.
command_queue
)
>
0
and
repo
.
command_queue
[
-
1
]
is
not
None
and
not
repo
.
command_queue
[
-
1
].
is_done
:
if
(
blocking
and
len
(
repo
.
command_queue
)
>
0
and
repo
.
command_queue
[
-
1
]
is
not
None
and
not
repo
.
command_queue
[
-
1
].
is_done
):
repo
.
command_queue
[
-
1
].
_process
.
kill
()
repo
.
command_queue
[
-
1
].
_process
.
kill
()
git_head_commit_url
=
repo
.
push_to_hub
(
git_head_commit_url
=
repo
.
push_to_hub
(
commit_message
=
commit_message
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
commit_message
=
commit_message
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
# push separately the model card to be independent from the rest of the model
# push separately the model card to be independent from the rest of the model
create_model_card
(
args
,
model_name
=
model_name
)
create_model_card
(
args
,
model_name
=
model_name
)
try
:
try
:
repo
.
push_to_hub
(
repo
.
push_to_hub
(
commit_message
=
"update model card README.md"
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
commit_message
=
"update model card README.md"
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
except
EnvironmentError
as
exc
:
except
EnvironmentError
as
exc
:
logger
.
error
(
f
"Error pushing update to the model card. Please read logs and retry.
\n
$
{
exc
}
"
)
logger
.
error
(
f
"Error pushing update to the model card. Please read logs and retry.
\n
$
{
exc
}
"
)
...
@@ -133,10 +140,7 @@ def create_model_card(args, model_name):
...
@@ -133,10 +140,7 @@ def create_model_card(args, model_name):
# TODO: replace this placeholder model card generation
# TODO: replace this placeholder model card generation
model_card
=
""
model_card
=
""
metadata
=
{
metadata
=
{
"license"
:
"apache-2.0"
,
"tags"
:
[
"pytorch"
,
"diffusers"
]}
"license"
:
"apache-2.0"
,
"tags"
:
[
"pytorch"
,
"diffusers"
]
}
metadata
=
yaml
.
dump
(
metadata
,
sort_keys
=
False
)
metadata
=
yaml
.
dump
(
metadata
,
sort_keys
=
False
)
if
len
(
metadata
)
>
0
:
if
len
(
metadata
)
>
0
:
model_card
=
f
"---
\n
{
metadata
}
---
\n
"
model_card
=
f
"---
\n
{
metadata
}
---
\n
"
...
...
src/diffusers/modeling_utils.py
View file @
9c82c32b
src/diffusers/models/__init__.py
View file @
9c82c32b
src/diffusers/models/unet_rl.py
View file @
9c82c32b
...
@@ -5,6 +5,7 @@ import math
...
@@ -5,6 +5,7 @@ import math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
try
:
try
:
import
einops
import
einops
from
einops.layers.torch
import
Rearrange
from
einops.layers.torch
import
Rearrange
...
@@ -103,7 +104,7 @@ class ResidualTemporalBlock(nn.Module):
...
@@ -103,7 +104,7 @@ class ResidualTemporalBlock(nn.Module):
return
out
+
self
.
residual_conv
(
x
)
return
out
+
self
.
residual_conv
(
x
)
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
#
(nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
#
(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
horizon
,
horizon
,
...
@@ -118,7 +119,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
...
@@ -118,7 +119,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim
=
dim
time_dim
=
dim
self
.
time_mlp
=
nn
.
Sequential
(
self
.
time_mlp
=
nn
.
Sequential
(
SinusoidalPosEmb
(
dim
),
SinusoidalPosEmb
(
dim
),
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
9c82c32b
...
@@ -137,8 +137,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -137,8 +137,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
pred_prev_sample
return
pred_prev_sample
def
forward_step
(
self
,
original_sample
,
noise
,
t
):
def
forward_step
(
self
,
original_sample
,
noise
,
t
):
sqrt_alpha_prod
=
self
.
alpha
_
prod
_t
[
t
]
**
0.5
sqrt_alpha_prod
=
self
.
alpha
s_cum
prod
[
t
]
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alpha
_
prod
_t
[
t
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alpha
s_cum
prod
[
t
])
**
0.5
noisy_sample
=
sqrt_alpha_prod
*
original_sample
+
sqrt_one_minus_alpha_prod
*
noise
noisy_sample
=
sqrt_alpha_prod
*
original_sample
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_sample
return
noisy_sample
...
...
tests/test_modeling_utils.py
View file @
9c82c32b
...
@@ -33,9 +33,9 @@ from diffusers import (
...
@@ -33,9 +33,9 @@ from diffusers import (
GLIDESuperResUNetModel
,
GLIDESuperResUNetModel
,
LatentDiffusion
,
LatentDiffusion
,
PNDMScheduler
,
PNDMScheduler
,
UNetModel
,
UNetLDMModel
,
UNetGradTTSModel
,
UNetGradTTSModel
,
UNetLDMModel
,
UNetModel
,
)
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
...
@@ -342,6 +342,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
...
@@ -342,6 +342,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
# fmt: on
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
UNetLDMModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
UNetLDMModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetLDMModel
model_class
=
UNetLDMModel
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment