Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
5e6f5000
Commit
5e6f5000
authored
Jun 17, 2022
by
Patrick von Platen
Browse files
rename register to register_to_config
parent
0ffda1df
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
20 additions
and
20 deletions
+20
-20
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+1
-1
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+1
-1
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+2
-2
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+1
-1
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+1
-1
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+2
-2
src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py
src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py
+2
-2
src/diffusers/pipelines/pipeline_bddm.py
src/diffusers/pipelines/pipeline_bddm.py
+1
-1
src/diffusers/pipelines/pipeline_grad_tts.py
src/diffusers/pipelines/pipeline_grad_tts.py
+1
-1
src/diffusers/pipelines/pipeline_latent_diffusion.py
src/diffusers/pipelines/pipeline_latent_diffusion.py
+2
-2
src/diffusers/schedulers/classifier_free_guidance.py
src/diffusers/schedulers/classifier_free_guidance.py
+1
-1
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+1
-1
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+1
-1
src/diffusers/schedulers/scheduling_grad_tts.py
src/diffusers/schedulers/scheduling_grad_tts.py
+1
-1
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+1
-1
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+1
-1
No files found.
src/diffusers/configuration_utils.py
View file @
5e6f5000
...
@@ -50,7 +50,7 @@ class ConfigMixin:
...
@@ -50,7 +50,7 @@ class ConfigMixin:
"""
"""
config_name
=
None
config_name
=
None
def
register
(
self
,
**
kwargs
):
def
register
_to_config
(
self
,
**
kwargs
):
if
self
.
config_name
is
None
:
if
self
.
config_name
is
None
:
raise
NotImplementedError
(
f
"Make sure that
{
self
.
__class__
}
has defined a class name `config_name`"
)
raise
NotImplementedError
(
f
"Make sure that
{
self
.
__class__
}
has defined a class name `config_name`"
)
kwargs
[
"_class_name"
]
=
self
.
__class__
.
__name__
kwargs
[
"_class_name"
]
=
self
.
__class__
.
__name__
...
...
src/diffusers/models/unet.py
View file @
5e6f5000
...
@@ -188,7 +188,7 @@ class UNetModel(ModelMixin, ConfigMixin):
...
@@ -188,7 +188,7 @@ class UNetModel(ModelMixin, ConfigMixin):
resolution
=
256
,
resolution
=
256
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
register
(
self
.
register
_to_config
(
ch
=
ch
,
ch
=
ch
,
out_ch
=
out_ch
,
out_ch
=
out_ch
,
ch_mult
=
ch_mult
,
ch_mult
=
ch_mult
,
...
...
src/diffusers/models/unet_glide.py
View file @
5e6f5000
...
@@ -689,7 +689,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
...
@@ -689,7 +689,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
resblock_updown
=
resblock_updown
,
resblock_updown
=
resblock_updown
,
transformer_dim
=
transformer_dim
,
transformer_dim
=
transformer_dim
,
)
)
self
.
register
(
self
.
register
_to_config
(
in_channels
=
in_channels
,
in_channels
=
in_channels
,
resolution
=
resolution
,
resolution
=
resolution
,
model_channels
=
model_channels
,
model_channels
=
model_channels
,
...
@@ -780,7 +780,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
...
@@ -780,7 +780,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
use_scale_shift_norm
=
use_scale_shift_norm
,
use_scale_shift_norm
=
use_scale_shift_norm
,
resblock_updown
=
resblock_updown
,
resblock_updown
=
resblock_updown
,
)
)
self
.
register
(
self
.
register
_to_config
(
in_channels
=
in_channels
,
in_channels
=
in_channels
,
resolution
=
resolution
,
resolution
=
resolution
,
model_channels
=
model_channels
,
model_channels
=
model_channels
,
...
...
src/diffusers/models/unet_grad_tts.py
View file @
5e6f5000
...
@@ -126,7 +126,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
...
@@ -126,7 +126,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
def
__init__
(
self
,
dim
,
dim_mults
=
(
1
,
2
,
4
),
groups
=
8
,
n_spks
=
None
,
spk_emb_dim
=
64
,
n_feats
=
80
,
pe_scale
=
1000
):
def
__init__
(
self
,
dim
,
dim_mults
=
(
1
,
2
,
4
),
groups
=
8
,
n_spks
=
None
,
spk_emb_dim
=
64
,
n_feats
=
80
,
pe_scale
=
1000
):
super
(
UNetGradTTSModel
,
self
).
__init__
()
super
(
UNetGradTTSModel
,
self
).
__init__
()
self
.
register
(
self
.
register
_to_config
(
dim
=
dim
,
dim
=
dim
,
dim_mults
=
dim_mults
,
dim_mults
=
dim_mults
,
groups
=
groups
,
groups
=
groups
,
...
...
src/diffusers/models/unet_ldm.py
View file @
5e6f5000
...
@@ -746,7 +746,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -746,7 +746,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
super
().
__init__
()
super
().
__init__
()
# register all __init__ params with self.register
# register all __init__ params with self.register
self
.
register
(
self
.
register
_to_config
(
image_size
=
image_size
,
image_size
=
image_size
,
in_channels
=
in_channels
,
in_channels
=
in_channels
,
model_channels
=
model_channels
,
model_channels
=
model_channels
,
...
...
src/diffusers/pipeline_utils.py
View file @
5e6f5000
...
@@ -77,13 +77,13 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -77,13 +77,13 @@ class DiffusionPipeline(ConfigMixin):
register_dict
=
{
name
:
(
library
,
class_name
)}
register_dict
=
{
name
:
(
library
,
class_name
)}
# save model index config
# save model index config
self
.
register
(
**
register_dict
)
self
.
register
_to_config
(
**
register_dict
)
# set models
# set models
setattr
(
self
,
name
,
module
)
setattr
(
self
,
name
,
module
)
register_dict
=
{
"_module"
:
self
.
__module__
.
split
(
"."
)[
-
1
]}
register_dict
=
{
"_module"
:
self
.
__module__
.
split
(
"."
)[
-
1
]}
self
.
register
(
**
register_dict
)
self
.
register
_to_config
(
**
register_dict
)
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
self
.
save_config
(
save_directory
)
self
.
save_config
(
save_directory
)
...
...
src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py
View file @
5e6f5000
...
@@ -655,7 +655,7 @@ class VQModel(ModelMixin, ConfigMixin):
...
@@ -655,7 +655,7 @@ class VQModel(ModelMixin, ConfigMixin):
super
().
__init__
()
super
().
__init__
()
# register all __init__ params with self.register
# register all __init__ params with self.register
self
.
register
(
self
.
register
_to_config
(
ch
=
ch
,
ch
=
ch
,
out_ch
=
out_ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
num_res_blocks
=
num_res_blocks
,
...
@@ -786,7 +786,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
...
@@ -786,7 +786,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
super
().
__init__
()
super
().
__init__
()
# register all __init__ params with self.register
# register all __init__ params with self.register
self
.
register
(
self
.
register
_to_config
(
ch
=
ch
,
ch
=
ch
,
out_ch
=
out_ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
num_res_blocks
=
num_res_blocks
,
...
...
src/diffusers/pipelines/pipeline_bddm.py
View file @
5e6f5000
...
@@ -232,7 +232,7 @@ class DiffWave(ModelMixin, ConfigMixin):
...
@@ -232,7 +232,7 @@ class DiffWave(ModelMixin, ConfigMixin):
super
().
__init__
()
super
().
__init__
()
# register all init arguments with self.register
# register all init arguments with self.register
self
.
register
(
self
.
register
_to_config
(
in_channels
=
in_channels
,
in_channels
=
in_channels
,
res_channels
=
res_channels
,
res_channels
=
res_channels
,
skip_channels
=
skip_channels
,
skip_channels
=
skip_channels
,
...
...
src/diffusers/pipelines/pipeline_grad_tts.py
View file @
5e6f5000
...
@@ -355,7 +355,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
...
@@ -355,7 +355,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
):
):
super
(
TextEncoder
,
self
).
__init__
()
super
(
TextEncoder
,
self
).
__init__
()
self
.
register
(
self
.
register
_to_config
(
n_vocab
=
n_vocab
,
n_vocab
=
n_vocab
,
n_feats
=
n_feats
,
n_feats
=
n_feats
,
n_channels
=
n_channels
,
n_channels
=
n_channels
,
...
...
src/diffusers/pipelines/pipeline_latent_diffusion.py
View file @
5e6f5000
...
@@ -656,7 +656,7 @@ class VQModel(ModelMixin, ConfigMixin):
...
@@ -656,7 +656,7 @@ class VQModel(ModelMixin, ConfigMixin):
super
().
__init__
()
super
().
__init__
()
# register all __init__ params with self.register
# register all __init__ params with self.register
self
.
register
(
self
.
register
_to_config
(
ch
=
ch
,
ch
=
ch
,
out_ch
=
out_ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
num_res_blocks
=
num_res_blocks
,
...
@@ -787,7 +787,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
...
@@ -787,7 +787,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
super
().
__init__
()
super
().
__init__
()
# register all __init__ params with self.register
# register all __init__ params with self.register
self
.
register
(
self
.
register
_to_config
(
ch
=
ch
,
ch
=
ch
,
out_ch
=
out_ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
num_res_blocks
=
num_res_blocks
,
...
...
src/diffusers/schedulers/classifier_free_guidance.py
View file @
5e6f5000
...
@@ -57,7 +57,7 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
...
@@ -57,7 +57,7 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
beta_schedule
=
"squaredcos_cap_v2"
,
beta_schedule
=
"squaredcos_cap_v2"
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
register
(
self
.
register
_to_config
(
timesteps
=
timesteps
,
timesteps
=
timesteps
,
beta_schedule
=
beta_schedule
,
beta_schedule
=
beta_schedule
,
)
)
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
5e6f5000
...
@@ -32,7 +32,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -32,7 +32,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
tensor_format
=
"np"
,
tensor_format
=
"np"
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
register
(
self
.
register
_to_config
(
timesteps
=
timesteps
,
timesteps
=
timesteps
,
beta_start
=
beta_start
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_end
=
beta_end
,
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
5e6f5000
...
@@ -33,7 +33,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -33,7 +33,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
tensor_format
=
"np"
,
tensor_format
=
"np"
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
register
(
self
.
register
_to_config
(
timesteps
=
timesteps
,
timesteps
=
timesteps
,
beta_start
=
beta_start
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_end
=
beta_end
,
...
...
src/diffusers/schedulers/scheduling_grad_tts.py
View file @
5e6f5000
...
@@ -25,7 +25,7 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
...
@@ -25,7 +25,7 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
tensor_format
=
"np"
,
tensor_format
=
"np"
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
register
(
self
.
register
_to_config
(
timesteps
=
timesteps
,
timesteps
=
timesteps
,
beta_start
=
beta_start
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_end
=
beta_end
,
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
5e6f5000
...
@@ -29,7 +29,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -29,7 +29,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
tensor_format
=
"np"
,
tensor_format
=
"np"
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
register
(
self
.
register
_to_config
(
timesteps
=
timesteps
,
timesteps
=
timesteps
,
beta_start
=
beta_start
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_end
=
beta_end
,
...
...
tests/test_modeling_utils.py
View file @
5e6f5000
...
@@ -57,7 +57,7 @@ class ConfigTester(unittest.TestCase):
...
@@ -57,7 +57,7 @@ class ConfigTester(unittest.TestCase):
d
=
"for diffusion"
,
d
=
"for diffusion"
,
e
=
[
1
,
3
],
e
=
[
1
,
3
],
):
):
self
.
register
(
a
=
a
,
b
=
b
,
c
=
c
,
d
=
d
,
e
=
e
)
self
.
register
_to_config
(
a
=
a
,
b
=
b
,
c
=
c
,
d
=
d
,
e
=
e
)
obj
=
SampleObject
()
obj
=
SampleObject
()
config
=
obj
.
config
config
=
obj
.
config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment