Commit 5e6f5000 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

rename register to register_to_config

parent 0ffda1df
...@@ -50,7 +50,7 @@ class ConfigMixin: ...@@ -50,7 +50,7 @@ class ConfigMixin:
""" """
config_name = None config_name = None
def register(self, **kwargs): def register_to_config(self, **kwargs):
if self.config_name is None: if self.config_name is None:
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
kwargs["_class_name"] = self.__class__.__name__ kwargs["_class_name"] = self.__class__.__name__
......
...@@ -188,7 +188,7 @@ class UNetModel(ModelMixin, ConfigMixin): ...@@ -188,7 +188,7 @@ class UNetModel(ModelMixin, ConfigMixin):
resolution=256, resolution=256,
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
ch=ch, ch=ch,
out_ch=out_ch, out_ch=out_ch,
ch_mult=ch_mult, ch_mult=ch_mult,
......
...@@ -689,7 +689,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel): ...@@ -689,7 +689,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
resblock_updown=resblock_updown, resblock_updown=resblock_updown,
transformer_dim=transformer_dim, transformer_dim=transformer_dim,
) )
self.register( self.register_to_config(
in_channels=in_channels, in_channels=in_channels,
resolution=resolution, resolution=resolution,
model_channels=model_channels, model_channels=model_channels,
...@@ -780,7 +780,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel): ...@@ -780,7 +780,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown, resblock_updown=resblock_updown,
) )
self.register( self.register_to_config(
in_channels=in_channels, in_channels=in_channels,
resolution=resolution, resolution=resolution,
model_channels=model_channels, model_channels=model_channels,
......
...@@ -126,7 +126,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -126,7 +126,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000): def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
super(UNetGradTTSModel, self).__init__() super(UNetGradTTSModel, self).__init__()
self.register( self.register_to_config(
dim=dim, dim=dim,
dim_mults=dim_mults, dim_mults=dim_mults,
groups=groups, groups=groups,
......
...@@ -746,7 +746,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ...@@ -746,7 +746,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all __init__ params with self.register # register all __init__ params with self.register
self.register( self.register_to_config(
image_size=image_size, image_size=image_size,
in_channels=in_channels, in_channels=in_channels,
model_channels=model_channels, model_channels=model_channels,
......
...@@ -77,13 +77,13 @@ class DiffusionPipeline(ConfigMixin): ...@@ -77,13 +77,13 @@ class DiffusionPipeline(ConfigMixin):
register_dict = {name: (library, class_name)} register_dict = {name: (library, class_name)}
# save model index config # save model index config
self.register(**register_dict) self.register_to_config(**register_dict)
# set models # set models
setattr(self, name, module) setattr(self, name, module)
register_dict = {"_module": self.__module__.split(".")[-1]} register_dict = {"_module": self.__module__.split(".")[-1]}
self.register(**register_dict) self.register_to_config(**register_dict)
def save_pretrained(self, save_directory: Union[str, os.PathLike]): def save_pretrained(self, save_directory: Union[str, os.PathLike]):
self.save_config(save_directory) self.save_config(save_directory)
......
...@@ -655,7 +655,7 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -655,7 +655,7 @@ class VQModel(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all __init__ params with self.register # register all __init__ params with self.register
self.register( self.register_to_config(
ch=ch, ch=ch,
out_ch=out_ch, out_ch=out_ch,
num_res_blocks=num_res_blocks, num_res_blocks=num_res_blocks,
...@@ -786,7 +786,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -786,7 +786,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all __init__ params with self.register # register all __init__ params with self.register
self.register( self.register_to_config(
ch=ch, ch=ch,
out_ch=out_ch, out_ch=out_ch,
num_res_blocks=num_res_blocks, num_res_blocks=num_res_blocks,
......
...@@ -232,7 +232,7 @@ class DiffWave(ModelMixin, ConfigMixin): ...@@ -232,7 +232,7 @@ class DiffWave(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all init arguments with self.register # register all init arguments with self.register
self.register( self.register_to_config(
in_channels=in_channels, in_channels=in_channels,
res_channels=res_channels, res_channels=res_channels,
skip_channels=skip_channels, skip_channels=skip_channels,
......
...@@ -355,7 +355,7 @@ class TextEncoder(ModelMixin, ConfigMixin): ...@@ -355,7 +355,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
): ):
super(TextEncoder, self).__init__() super(TextEncoder, self).__init__()
self.register( self.register_to_config(
n_vocab=n_vocab, n_vocab=n_vocab,
n_feats=n_feats, n_feats=n_feats,
n_channels=n_channels, n_channels=n_channels,
......
...@@ -656,7 +656,7 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -656,7 +656,7 @@ class VQModel(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all __init__ params with self.register # register all __init__ params with self.register
self.register( self.register_to_config(
ch=ch, ch=ch,
out_ch=out_ch, out_ch=out_ch,
num_res_blocks=num_res_blocks, num_res_blocks=num_res_blocks,
...@@ -787,7 +787,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -787,7 +787,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all __init__ params with self.register # register all __init__ params with self.register
self.register( self.register_to_config(
ch=ch, ch=ch,
out_ch=out_ch, out_ch=out_ch,
num_res_blocks=num_res_blocks, num_res_blocks=num_res_blocks,
......
...@@ -57,7 +57,7 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin): ...@@ -57,7 +57,7 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
beta_schedule="squaredcos_cap_v2", beta_schedule="squaredcos_cap_v2",
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
timesteps=timesteps, timesteps=timesteps,
beta_schedule=beta_schedule, beta_schedule=beta_schedule,
) )
......
...@@ -32,7 +32,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -32,7 +32,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
tensor_format="np", tensor_format="np",
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
timesteps=timesteps, timesteps=timesteps,
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,
......
...@@ -33,7 +33,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -33,7 +33,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
tensor_format="np", tensor_format="np",
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
timesteps=timesteps, timesteps=timesteps,
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,
......
...@@ -25,7 +25,7 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin): ...@@ -25,7 +25,7 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
tensor_format="np", tensor_format="np",
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
timesteps=timesteps, timesteps=timesteps,
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,
......
...@@ -29,7 +29,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -29,7 +29,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
tensor_format="np", tensor_format="np",
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
timesteps=timesteps, timesteps=timesteps,
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,
......
...@@ -57,7 +57,7 @@ class ConfigTester(unittest.TestCase): ...@@ -57,7 +57,7 @@ class ConfigTester(unittest.TestCase):
d="for diffusion", d="for diffusion",
e=[1, 3], e=[1, 3],
): ):
self.register(a=a, b=b, c=c, d=d, e=e) self.register_to_config(a=a, b=b, c=c, d=d, e=e)
obj = SampleObject() obj = SampleObject()
config = obj.config config = obj.config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment