Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
9c82c32b
Commit
9c82c32b
authored
Jun 21, 2022
by
anton-l
Browse files
make style
parent
1a099e5e
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
40 additions
and
35 deletions
+40
-35
examples/train_unconditional.py
examples/train_unconditional.py
+3
-3
src/diffusers/hub_utils.py
src/diffusers/hub_utils.py
+25
-21
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+1
-1
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+1
-1
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+2
-2
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+2
-2
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+6
-5
No files found.
examples/train_unconditional.py
View file @
9c82c32b
...
...
@@ -8,6 +8,9 @@ import PIL.Image
from
accelerate
import
Accelerator
from
datasets
import
load_dataset
from
diffusers
import
DDPM
,
DDPMScheduler
,
UNetModel
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.modeling_utils
import
unwrap_model
from
diffusers.utils
import
logging
from
torchvision.transforms
import
(
CenterCrop
,
Compose
,
...
...
@@ -19,10 +22,7 @@ from torchvision.transforms import (
)
from
tqdm.auto
import
tqdm
from
transformers
import
get_linear_schedule_with_warmup
from
diffusers.modeling_utils
import
unwrap_model
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
...
...
src/diffusers/hub_utils.py
View file @
9c82c32b
from
typing
import
Optional
from
.utils
import
logging
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
import
yaml
import
os
from
pathlib
import
Path
import
shutil
from
pathlib
import
Path
from
typing
import
Optional
import
yaml
from
diffusers
import
DiffusionPipeline
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -68,17 +70,21 @@ def init_git_repo(args, at_init: bool = False):
repo
.
git_pull
()
# By default, ignore the checkpoint folders
if
(
not
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
))
and
args
.
hub_strategy
!=
"all_checkpoints"
):
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
))
and
args
.
hub_strategy
!=
"all_checkpoints"
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
writelines
([
"checkpoint-*/"
])
return
repo
def
push_to_hub
(
args
,
pipeline
:
DiffusionPipeline
,
repo
:
Repository
,
commit_message
:
Optional
[
str
]
=
"End of training"
,
blocking
:
bool
=
True
,
**
kwargs
)
->
str
:
def
push_to_hub
(
args
,
pipeline
:
DiffusionPipeline
,
repo
:
Repository
,
commit_message
:
Optional
[
str
]
=
"End of training"
,
blocking
:
bool
=
True
,
**
kwargs
,
)
->
str
:
"""
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
Parameters:
...
...
@@ -108,18 +114,19 @@ def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_mess
return
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
if
blocking
and
len
(
repo
.
command_queue
)
>
0
and
repo
.
command_queue
[
-
1
]
is
not
None
and
not
repo
.
command_queue
[
-
1
].
is_done
:
if
(
blocking
and
len
(
repo
.
command_queue
)
>
0
and
repo
.
command_queue
[
-
1
]
is
not
None
and
not
repo
.
command_queue
[
-
1
].
is_done
):
repo
.
command_queue
[
-
1
].
_process
.
kill
()
git_head_commit_url
=
repo
.
push_to_hub
(
commit_message
=
commit_message
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
git_head_commit_url
=
repo
.
push_to_hub
(
commit_message
=
commit_message
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
# push separately the model card to be independent from the rest of the model
create_model_card
(
args
,
model_name
=
model_name
)
try
:
repo
.
push_to_hub
(
commit_message
=
"update model card README.md"
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
repo
.
push_to_hub
(
commit_message
=
"update model card README.md"
,
blocking
=
blocking
,
auto_lfs_prune
=
True
)
except
EnvironmentError
as
exc
:
logger
.
error
(
f
"Error pushing update to the model card. Please read logs and retry.
\n
$
{
exc
}
"
)
...
...
@@ -133,10 +140,7 @@ def create_model_card(args, model_name):
# TODO: replace this placeholder model card generation
model_card
=
""
metadata
=
{
"license"
:
"apache-2.0"
,
"tags"
:
[
"pytorch"
,
"diffusers"
]
}
metadata
=
{
"license"
:
"apache-2.0"
,
"tags"
:
[
"pytorch"
,
"diffusers"
]}
metadata
=
yaml
.
dump
(
metadata
,
sort_keys
=
False
)
if
len
(
metadata
)
>
0
:
model_card
=
f
"---
\n
{
metadata
}
---
\n
"
...
...
src/diffusers/modeling_utils.py
View file @
9c82c32b
...
...
@@ -585,4 +585,4 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
if
hasattr
(
model
,
"module"
):
return
unwrap_model
(
model
.
module
)
else
:
return
model
\ No newline at end of file
return
model
src/diffusers/models/__init__.py
View file @
9c82c32b
...
...
@@ -20,4 +20,4 @@ from .unet import UNetModel
from
.unet_glide
import
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
,
GLIDEUNetModel
from
.unet_grad_tts
import
UNetGradTTSModel
from
.unet_ldm
import
UNetLDMModel
from
.unet_rl
import
TemporalUNet
\ No newline at end of file
from
.unet_rl
import
TemporalUNet
src/diffusers/models/unet_rl.py
View file @
9c82c32b
...
...
@@ -5,6 +5,7 @@ import math
import
torch
import
torch.nn
as
nn
try
:
import
einops
from
einops.layers.torch
import
Rearrange
...
...
@@ -103,7 +104,7 @@ class ResidualTemporalBlock(nn.Module):
return
out
+
self
.
residual_conv
(
x
)
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
#
(nn.Module):
class
TemporalUNet
(
ModelMixin
,
ConfigMixin
):
#
(nn.Module):
def
__init__
(
self
,
horizon
,
...
...
@@ -118,7 +119,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim
=
dim
self
.
time_mlp
=
nn
.
Sequential
(
SinusoidalPosEmb
(
dim
),
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
9c82c32b
...
...
@@ -137,8 +137,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
pred_prev_sample
def
forward_step
(
self
,
original_sample
,
noise
,
t
):
sqrt_alpha_prod
=
self
.
alpha
_
prod
_t
[
t
]
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alpha
_
prod
_t
[
t
])
**
0.5
sqrt_alpha_prod
=
self
.
alpha
s_cum
prod
[
t
]
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alpha
s_cum
prod
[
t
])
**
0.5
noisy_sample
=
sqrt_alpha_prod
*
original_sample
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_sample
...
...
tests/test_modeling_utils.py
View file @
9c82c32b
...
...
@@ -33,9 +33,9 @@ from diffusers import (
GLIDESuperResUNetModel
,
LatentDiffusion
,
PNDMScheduler
,
UNetModel
,
UNetLDMModel
,
UNetGradTTSModel
,
UNetLDMModel
,
UNetModel
,
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
...
...
@@ -342,6 +342,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
UNetLDMModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetLDMModel
...
...
@@ -378,7 +379,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetLDMModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
...
...
@@ -446,7 +447,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetGradTTSModel
.
from_pretrained
(
"fusing/unet-grad-tts-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
...
...
@@ -464,7 +465,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
num_features
=
model
.
config
.
n_feats
seq_len
=
16
noise
=
torch
.
randn
((
1
,
num_features
,
seq_len
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment